diff --git a/mttl/models/modifiers/sm_updater.py b/mttl/models/modifiers/sm_updater.py index d476b6951..ea9ffed25 100644 --- a/mttl/models/modifiers/sm_updater.py +++ b/mttl/models/modifiers/sm_updater.py @@ -11,9 +11,9 @@ from mttl.logging import logger from mttl.models.modifiers.base import Modifier -from mttl.models.modifiers.sm_config import SparseMaskConfig -from mttl.models.modifiers.sparse_utils.sparse_linear import MaskedLinear, SparseLinear -from mttl.models.modifiers.sparse_utils.utils import ( +from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig +from mttl.models.modifiers.sparsity.sparse_linear import MaskedLinear, SparseLinear +from mttl.models.modifiers.sparsity.sparse_utils.utils import ( get_2d_indices_from_csr_matrix, get_top_k_sparcity, scipy_csr_to_torch_csr, diff --git a/mttl/models/modifiers/sparse_mask.py b/mttl/models/modifiers/sparse_mask.py index cca4973ce..82c77d8fd 100644 --- a/mttl/models/modifiers/sparse_mask.py +++ b/mttl/models/modifiers/sparse_mask.py @@ -3,28 +3,19 @@ from dataclasses import dataclass from typing import Union -import numpy as np -import torch -from scipy.sparse import csr_matrix from torch import nn from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut from mttl.logging import logger from mttl.models.modifiers.base import Modifier, ModifierConfig -from mttl.models.modifiers.sm_config import SparseMaskConfig -from mttl.models.modifiers.sm_updater import MaskUpdater -from mttl.models.modifiers.sparse_utils.sparse_linear import ( +from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig +from mttl.models.modifiers.sparsity.mask_updater import MaskUpdater +from mttl.models.modifiers.sparsity.sparse_linear import ( MaskedLinear, ScatteredSparseLinearModule, SparseLinear, SparseLinearConfig, ) -from mttl.models.modifiers.sparse_utils.utils import ( - get_2d_indices_from_csr_matrix, - get_top_k_sparcity, - scipy_csr_to_torch_csr, - torch_csr_to_scipy_csr, -) class SparseMaskAdapter(Modifier): @@ -55,7 +46,7 @@ def __init__( def forward(self, input): if self.maks_update_mode and self.training: - return self.mask_updater(self.sparse_layer, input) + return self.mask_updater(input) return self.sparse_layer(input) def prepare_for_mask_update(self): diff --git a/mttl/models/modifiers/sm_config.py b/mttl/models/modifiers/sparse_mask_config.py similarity index 89% rename from mttl/models/modifiers/sm_config.py rename to mttl/models/modifiers/sparse_mask_config.py index 0058b9828..4261392fd 100644 --- a/mttl/models/modifiers/sm_config.py +++ b/mttl/models/modifiers/sparse_mask_config.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from mttl.models.modifiers.sparse_utils.sparse_linear import SparseLinearConfig +from mttl.models.modifiers.sparsity.sparse_linear import SparseLinearConfig @dataclass diff --git a/mttl/models/modifiers/sparse_utils/profile_block_sparsity.py b/mttl/models/modifiers/sparse_utils/profile_block_sparsity.py deleted file mode 100644 index 58a08fc6b..000000000 --- a/mttl/models/modifiers/sparse_utils/profile_block_sparsity.py +++ /dev/null @@ -1,200 +0,0 @@ -# several options to compare for block sparce operations: -# 1. triton.ops.blocksparse -- this is supposed to work fast for cases whee the sparcity structure is not changing too often -# 2. 
stk -- https://github.com/stanford-futuredata/stk -- this is supposed to work fast for cases where the sparcity structure is changing often -import stk -import stk.ops -import torch -import torch.nn.functional as F -import triton as tn -from spops import csr_add, sddmm -from triton.ops.blocksparse import matmul - -from mttl.models.modifiers.sparse_mask import SparseMaskConfig, SparseWeights -from mttl.models.modifiers.sparse_utils.utils import init_sparse_weights - -n_blocks = 4 -BLOCK_SIZE = 128 -dtype = torch.float16 - -sequence_length = 1024 -hidden_size = 2048 # phi2 size -mlp_size = 8192 - -sparcity = n_blocks * (BLOCK_SIZE**2) / (hidden_size * mlp_size) -print(f"sparsity: {sparcity}") - -# W = init_sparse_weights("block_sparse", 0.005, (K, N), BLOCK_SIZE).contiguous().to('cuda') -# X = torch.randn(M, K).to('cuda').contiguous() - - -def stk_sdd(X, W, topo): - return stk.ops.sdd(X, W, topo) - - -def torch_linear(X, W): - return F.linear(X, W) - - -def spops_sdd_structure_aware(X, W, topo: SparseWeights): - return sddmm(topo.row_offs, topo.row_idx, topo.col_idx, X, W) - - -def spops_sdd_sputnik(X, W, topo: SparseWeights): - return sddmm(topo.row_offs, topo.row_idx, topo.col_idx, X, W, backend="sputnik") - - -def torch_linear_w_sparse(X, W): - return F.linear(X, W) - - -def triton_blocksparse_mm(X, W, op): - return op(X, W) - - -def prepare_triton_bs_op(X, W): - Z, H = 1, 1 - AT = False - BT = False - op_mode = "sdd" - - def to_block_sparse_layout(matrix: torch.Tensor, block_size: int): - """ - Returns layout of block sparse matrix: i.e. a matrix of shape (M//block_size, N//block_size) where each element is a boolean indicating whether the block is non-zero. - """ - M, N = matrix.shape - assert M % block_size == 0, "M must be divisible by block_size" - assert N % block_size == 0, "N must be divisible by block_size" - matrix = matrix.reshape( - M // block_size, - block_size, - N // block_size, - block_size, - ).permute(0, 2, 1, 3) - matrix = matrix.flatten(2, 3).sum(dim=-1) - return matrix.cpu().bool().to(torch.int64) - - layout = to_block_sparse_layout(W, BLOCK_SIZE).unsqueeze(0) - # creat inputs - op = matmul(layout, BLOCK_SIZE, op_mode, trans_a=AT, trans_b=BT, device="cuda") - return op - - -# # adapted from https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html -@tn.testing.perf_report( - tn.testing.Benchmark( - # x_names=['o'], # Argument names to use as an x-axis for the plot. - # x_vals=[128*i for i in [8, 10, 20, 50, 64, 100]], # Different possible values for `x_name`. - x_names=["s"], # Argument names to use as an x-axis for the plot. - x_vals=[ - 128 * i for i in [8, 10, 12, 14, 16, 20] - ], # Different possible values for `x_name`. - x_log=False, # x axis is logarithmic. - line_arg="provider", # Argument name whose value corresponds to a different line in the plot. - line_vals=["naive", "stk", "triton_bs"], # Possible values for `line_arg`. - line_names=["Naive", "stk", "triton_bs"], # Label name for the lines. - styles=[ - ("blue", "-"), - ("green", "-"), - ("orange", "-"), - ("red", "-"), - ("purple", "-"), - ("black", "-"), - ], # Line color and style. - ylabel="ms", #'GB/s', # Label name for the y-axis. - xlabel="seq length dim.", - plot_name="matmul-performance", # Name for the plot. Used also as a file name for saving the plot. - args={ - "h": hidden_size, - "o": mlp_size, - "sp": sparcity, - }, # Values for function arguments not in `x_names` and `y_name`. 
- ) -) -def benchmark(s, h, o, sp, provider): - X = torch.rand((s, h), device="cuda", dtype=dtype).contiguous() - W = ( - init_sparse_weights("block_sparse", sp, (h, o), BLOCK_SIZE) - .to("cuda") - .to(dtype) - .contiguous() - ) - W_row_sparse = ( - init_sparse_weights("row_sparse", sp, (h, o), BLOCK_SIZE) - .to("cuda") - .to(dtype) - .contiguous() - ) - WT = W.T - assert W.sum() > 0 - assert W_row_sparse.sum() == W.sum() - - quantiles = [0.5, 0.2, 0.8] - if provider == "naive": - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: torch_linear(X, WT), quantiles=quantiles - ) - if provider == "stk": - if BLOCK_SIZE != 128 or dtype != torch.float16: - ms, min_ms, max_ms = 0, 0, 0 - else: - W_stk = stk.ops.to_sparse(W, blocking=BLOCK_SIZE) - W_stk.validate() - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: stk_sdd(X, W, W_stk), quantiles=quantiles - ) - if provider == "torch_bsr": - W_bst = WT.to_sparse_bsr(blocksize=BLOCK_SIZE) - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: torch_linear(X, W_bst), quantiles=quantiles - ) - if provider == "spops_block": - W_spops_block = SparseWeights.from_dense( - W, - SparseMaskConfig( - keep_ratio=sp, block_size=BLOCK_SIZE, sps_type="block_sparse" - ), - ).to("cuda") - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: spops_sdd_structure_aware(X, W, W_spops_block), quantiles=quantiles - ) - if provider == "spops_row": - W_row_sparse = ( - init_sparse_weights("row_sparse", sp, (h, o), BLOCK_SIZE) - .to("cuda") - .to(dtype) - .contiguous() - ) - W_spops_row = SparseWeights.from_dense( - W_row_sparse, - SparseMaskConfig( - keep_ratio=sp, block_size=BLOCK_SIZE, sps_type="row_sparse" - ), - ).to("cuda") - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: spops_sdd_structure_aware(X, W_row_sparse, W_spops_row), - quantiles=quantiles, - ) - if provider == "spops_sputnik_block": - W_spops_block = SparseWeights.from_dense( - W, - SparseMaskConfig( - keep_ratio=sp, block_size=BLOCK_SIZE, sps_type="block_sparse" - ), - ).to("cuda") - ms, min_ms, max_ms = tn.testing.do_bench( - lambda: spops_sdd_sputnik(X, W, W_spops_block), quantiles=quantiles - ) - if provider == "triton_bs": - op = prepare_triton_bs_op(X, W) - X = X[None, None, ...] - W = W[None, None, ...] 
- ms, min_ms, max_ms = tn.testing.do_bench( - lambda: triton_blocksparse_mm(X, W, op), quantiles=quantiles - ) - - gbps = lambda ms: 2 * s * h * o * 2 * 1e-9 / (ms * 1e-3) - # return gbps(ms), gbps(max_ms), gbps(min_ms) - return ms, max_ms, min_ms - - -benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/mttl/models/modifiers/sparse_utils/profile_sparse_mask.py b/mttl/models/modifiers/sparse_utils/profile_sparse_mask.py deleted file mode 100644 index 39caa85a7..000000000 --- a/mttl/models/modifiers/sparse_utils/profile_sparse_mask.py +++ /dev/null @@ -1,329 +0,0 @@ -import logging -import re -import time - -import numpy as np -import pandas as pd -import torch -from pytorch_lightning import seed_everything - -from mttl.logging import logger -from mttl.models.modifiers import modify_transformer -from mttl.models.modifiers.lora import LoRAConfig -from mttl.models.modifiers.sparse_mask import ( - MaskedLinear, - ScatteredSparseLinearModule, - SparseLinearModule, - SparseMaskConfig, -) -from mttl.models.utils import model_loader_helper, transfer_batch_to_device - -logger.setLevel(logging.ERROR) -model_name = "EleutherAI/gpt-neo-125m" # "EleutherAI/gpt-neo-125m" # "phi-2" -block_size = 128 -n_blocks = 6 -mask_updater = None -modify_layers = ".*q_proj.*|.*v_proj.*|.*k_proj.*" # ".*q_proj.*|.*v_proj.*|.*k_proj.*" # ".*Wqkv.*" # -n_iters = 50 - -# input sizes and batch sizes for testing -max_seq_len = 1024 -bs = 1 -vocab_size = 32000 - - -def calculate_lora_parameters(input_dim, output_dim, rank): - return input_dim * rank + output_dim * rank - - -def find_hyperpaams(): - - model = model_loader_helper( - model_name, - bf16=True, - fp16=False, - load_in_4bit=False, - load_in_8bit=False, - device_map="cpu", - ) - modules = dict(model.named_modules()) - modified_modules = {} - keep_ratios = [] - lora_ranks = [] - - for ml in modify_layers.split("|"): - for name, module in modules.items(): - if re.match(ml, name) and ml not in modified_modules: - keep_ratio = ( - n_blocks - * (block_size**2) - / (module.in_features * module.out_features) - ) - tot_sparse_params = ( - module.in_features * module.out_features * keep_ratio - ) - lora_rank = 1 - for rank in range(1, module.in_features): - lora_params = calculate_lora_parameters( - module.in_features, module.out_features, rank - ) - if lora_params <= tot_sparse_params: - lora_rank = rank - else: - break - modified_modules[ml] = { - "module": module, - "keep_ratio": keep_ratio, - "lora_rank": lora_rank, - } - keep_ratios.append(keep_ratio) - lora_ranks.append(lora_rank) - return np.mean(keep_ratios), int(np.mean(lora_ranks)) - - -keep_ratio, lora_rank = find_hyperpaams() -print(f"Keep ratio: {keep_ratio}, LoRA rank: {lora_rank}") - - -table = pd.DataFrame( - columns=[ - "Av. Runtime", - "Av. Forward time", - "Av. 
Backward time", - "Allocated Memory", - "Reserved Memory", - "Number of Parameters", - ] -) - - -def dummy_batch(): - torch.manual_seed(0) - batch = { - "input_ids": torch.randint(10, vocab_size, (bs, max_seq_len)), - "labels": torch.randint(10, vocab_size, (bs, max_seq_len)), - } - seq_len = torch.randint(0, max_seq_len, (bs,)) - attn_mask = torch.zeros(bs, max_seq_len, dtype=torch.int32) - attn_mask[torch.arange(bs), seq_len] = 1 - attn_mask = 1 - attn_mask.cumsum(dim=-1) - batch["attention_mask"] = attn_mask - return batch - - -def benchmark_module(module, runs=100): - # Set up inputs - input_data = dummy_batch() - input_data = transfer_batch_to_device(input_data, "cuda") - - # Warm-up to ensure accurate measurement - for _ in range(10): - loss = module(**input_data).loss - loss.backward() - module.zero_grad() - - forward_time_total = 0.0 - backward_time_total = 0.0 - - # Benchmark runs - for _ in range(runs): - # Forward pass timing - torch.cuda.synchronize() - start_time = time.time() - loss = module(**input_data).loss - torch.cuda.synchronize() - forward_time = time.time() - start_time - - # Backward pass timing - torch.cuda.synchronize() - start_time = time.time() - loss.backward() - torch.cuda.synchronize() - backward_time = time.time() - start_time - - # Zero gradients - module.zero_grad() - - # Accumulate times - forward_time_total += forward_time - backward_time_total += backward_time - - avg_forward_time = forward_time_total / runs - avg_backward_time = backward_time_total / runs - avg_runtime = avg_forward_time + avg_backward_time - - # Measure memory usage - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - loss = module(**input_data).loss # Forward pass to record memory - loss.backward() # Backward pass to record memory - memory_allocated = torch.cuda.max_memory_allocated() - memory_reserved = torch.cuda.max_memory_reserved() - - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - return ( - avg_runtime, - avg_forward_time, - avg_backward_time, - memory_allocated, - memory_reserved, - ) - - -def run_benchmark(name, adapter_config): - seed_everything(0) - model = model_loader_helper( - model_name, - bf16=True, - fp16=False, - load_in_4bit=False, - load_in_8bit=False, - device_map="cpu", - ) - modify_transformer(model, adapter_config) - model.to("cuda") - n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - - runtime, forward_time, backward_time, sparse_alloc, sparse_reserved = ( - benchmark_module(model, runs=n_iters) - ) - print( - f"{name} - Runtime: {runtime:.6f}s, Allocated Memory: {sparse_alloc / 1e6:.2f}MB, Reserved Memory: {sparse_reserved / 1e6:.2f}MB" - ) - table.loc[name] = [ - runtime, - forward_time, - backward_time, - sparse_alloc, - sparse_reserved, - n_params, - ] - - -############################################################################################################################################################ -# Benchmarking LoRA - -adapter_config = LoRAConfig(modify_layers=modify_layers, lora_rank=lora_rank) -run_benchmark("LoRA", adapter_config) - -################################################################################################################################################################# -# Benchmarking BlcockSparseLinearModule - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="triton_block_sparse", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) 
-run_benchmark("BlockSparseLinearModule", adapter_config) - - -################################################################################################################################################################# -# Benchmarking SparseLinearModule - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="sp_add+sp_mm", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - # mask_updater=mask_updater, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("SparseLinearModule (reg sp.)", adapter_config) - -############################################################################################################################################################ -# Benchmarking SparseLinearModule with block sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="sp_add+sp_mm", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("SparseLinearModule (block sp.)", adapter_config) - -################################################################################################################################################################# -# Benchmarking SPiEL with regular sparsity kernel - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="spiel", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("Spiel Linear (reg. sp)", adapter_config) - - -################################################################################################################################################################# -# Benchmarking MaskedLinear with regular sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="masked_linear", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("MaskedLinear (reg. 
sp)", adapter_config) - -############################################################################################################################################################ -# Benchmarking MaskedLinear with block sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="masked_linear", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("MaskedLinear (block sp.)", adapter_config) - -################################################################################################################################################################# -# Benchmarking ScatteredSparseLinearModule - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="scattered", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("ScatteredSparseLinearModule (block sp.)", adapter_config) - -############################################################################################################################################################ -# Benchmarking ScatteredSparseLinearModule with regular sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="scattered", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("ScatteredSparseLinearModule (reg sp.)", adapter_config) - -############################################################################################################################################################ -print(table) -# write table to a csv file -table.to_csv("benchmark_results.csv") diff --git a/mttl/models/modifiers/sparse_utils/profile_sparse_mask_only_linear.py b/mttl/models/modifiers/sparse_utils/profile_sparse_mask_only_linear.py deleted file mode 100644 index cb91c3be3..000000000 --- a/mttl/models/modifiers/sparse_utils/profile_sparse_mask_only_linear.py +++ /dev/null @@ -1,352 +0,0 @@ -import logging -import re -import time - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -from pytorch_lightning import seed_everything - -from mttl.logging import logger -from mttl.models.modifiers import modify_transformer -from mttl.models.modifiers.lora import LoRA, LoRAConfig -from mttl.models.modifiers.sparse_mask import ( - MaskedLinear, - ScatteredSparseLinearModule, - SparseLinearModule, - SparseMaskAdapter, - SparseMaskConfig, -) -from mttl.models.utils import model_loader_helper, transfer_batch_to_device - -logger.setLevel(logging.ERROR) -model_name = "EleutherAI/gpt-neo-125m" # "EleutherAI/gpt-neo-125m" # "phi-2" -block_size = 64 -n_blocks = 128 -mask_updater = None -modify_layers = ".*q_proj.*|.*v_proj.*|.*k_proj.*" # ".*q_proj.*|.*v_proj.*|.*k_proj.*" # ".*Wqkv.*" # -n_iters = 50 - -in_d = 2048 -out_d = 8192 * 2 -dtype = torch.bfloat16 - -# input sizes and batch sizes for testing -max_seq_len = 1024 -bs = 5 -vocab_size = 32000 - - -def calculate_lora_parameters(input_dim, output_dim, rank): - return input_dim * rank + output_dim * rank - - -layer = nn.Linear(in_d, out_d) -layer.weight.requires_grad_(False) -layer.bias.requires_grad_(False) - - -def find_hyperpaams(): - modules = {"linear": layer} - modified_modules = {} - keep_ratios = [] - lora_ranks = [] - - for name, module in modules.items(): - keep_ratio = ( - n_blocks * (block_size**2) / (module.in_features * module.out_features) - ) - tot_sparse_params = module.in_features * module.out_features * 
keep_ratio - lora_rank = 1 - for rank in range(1, module.in_features): - lora_params = calculate_lora_parameters( - module.in_features, module.out_features, rank - ) - if lora_params <= tot_sparse_params: - lora_rank = rank - else: - break - modified_modules[name] = { - "module": module, - "keep_ratio": keep_ratio, - "lora_rank": lora_rank, - } - keep_ratios.append(keep_ratio) - lora_ranks.append(lora_rank) - return np.mean(keep_ratios), int(np.mean(lora_ranks)) - - -keep_ratio, lora_rank = find_hyperpaams() -print(f"Keep ratio: {keep_ratio}, LoRA rank: {lora_rank}") - - -table = pd.DataFrame( - columns=[ - "Av. Runtime", - "Av. Forward time", - "Av. Backward time", - "Allocated Memory", - "Reserved Memory", - "Number of Parameters", - ] -) - - -def dummy_batch(): - torch.manual_seed(0) - batch = { - "input_ids": torch.randint(10, vocab_size, (bs, max_seq_len)), - "labels": torch.randint(10, vocab_size, (bs, max_seq_len)), - } - seq_len = torch.randint(0, max_seq_len, (bs,)) - attn_mask = torch.zeros(bs, max_seq_len, dtype=torch.int32) - attn_mask[torch.arange(bs), seq_len] = 1 - attn_mask = 1 - attn_mask.cumsum(dim=-1) - batch["attention_mask"] = attn_mask - return batch - - -def benchmark_module(module, runs=100): - # Set up inputs - input_data = dummy_batch() - input_data = torch.rand(bs, max_seq_len, in_d).to("cuda").to(dtype) - - # Warm-up to ensure accurate measurement - for _ in range(10): - out = module(input_data) - loss = torch.mean(out) - loss.backward() - module.zero_grad() - - forward_time_total = 0.0 - backward_time_total = 0.0 - - # Benchmark runs - for _ in range(runs): - # Forward pass timing - torch.cuda.synchronize() - start_time = time.time() - out = module(input_data) - loss = torch.mean(out) - torch.cuda.synchronize() - forward_time = time.time() - start_time - - # Backward pass timing - torch.cuda.synchronize() - start_time = time.time() - loss.backward() - torch.cuda.synchronize() - backward_time = time.time() - start_time - - # Zero gradients - module.zero_grad() - - # Accumulate times - forward_time_total += forward_time - backward_time_total += backward_time - - avg_forward_time = forward_time_total / runs - avg_backward_time = backward_time_total / runs - avg_runtime = avg_forward_time + avg_backward_time - - # Measure memory usage - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - out = module(input_data) # Forward pass to record memory - loss = torch.mean(out) - loss.backward() # Backward pass to record memory - memory_allocated = torch.cuda.max_memory_allocated() - memory_reserved = torch.cuda.max_memory_reserved() - - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - return ( - avg_runtime, - avg_forward_time, - avg_backward_time, - memory_allocated, - memory_reserved, - ) - - -def run_benchmark(name, adapter_config): - seed_everything(0) - if isinstance(adapter_config, LoRAConfig): - module = LoRA(adapter_config, layer) - else: - module = SparseMaskAdapter(adapter_config, layer) - - module.to("cuda").to(dtype) - n_params = sum(p.numel() for p in module.parameters() if p.requires_grad) - - runtime, forward_time, backward_time, sparse_alloc, sparse_reserved = ( - benchmark_module(module, runs=n_iters) - ) - print( - f"{name} - Runtime: {runtime:.6f}s, Allocated Memory: {sparse_alloc / 1e6:.2f}MB, Reserved Memory: {sparse_reserved / 1e6:.2f}MB, Number of Parameters: {n_params}" - ) - table.loc[name] = [ - runtime, - forward_time, - backward_time, - sparse_alloc, - sparse_reserved, - n_params, - ] - 
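The keep-ratio / LoRA-rank matching used by the profilers removed in this patch boils down to a short parameter-budget calculation. A standalone sketch, assuming only the constants defined at the top of this script (in_d=2048, out_d=8192*2, block_size=64, n_blocks=128); it is not part of the patch:

in_d, out_d = 2048, 8192 * 2
block_size, n_blocks = 64, 128

# trainable entries granted to the sparse adapter: n_blocks blocks of block_size x block_size
sparse_params = n_blocks * block_size**2          # 524_288
keep_ratio = sparse_params / (in_d * out_d)       # 524_288 / 33_554_432 = 0.015625

# largest LoRA rank whose parameter count rank * (in_d + out_d) fits the same budget,
# mirroring the search loop in find_hyperpaams above
lora_rank = sparse_params // (in_d + out_d)       # 524_288 // 18_432 = 28
print(keep_ratio, lora_rank)                      # 0.015625 28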
- -############################################################################################################################################################ -# Benchmarking LoRA - -adapter_config = LoRAConfig(modify_layers=modify_layers, lora_rank=lora_rank) -run_benchmark("LoRA", adapter_config) - -################################################################################################################################################################# -# Benchmarking BlcockSparseLinearModule + Dense without spoops - -# adapter_config = SparseMaskConfig( -# modify_layers=modify_layers, -# sps_impl="dense+triton_block_sparse", -# sps_type="block_sparse", -# keep_ratio=keep_ratio, -# reselection_steps=1, -# block_size=block_size, -# ) -# run_benchmark("BlockSparseLinearModule + Dense", adapter_config) - -################################################################################################################################################################# -# Benchmarking BlcockSparseLinearModule without spoops - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="triton_block_sparse_scatter", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("BlockSparseLinearModule (scatter add)", adapter_config) - -################################################################################################################################################################# -# Benchmarking BlcockSparseLinearModule - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="triton_block_sparse", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("BlockSparseLinearModule", adapter_config) - - -################################################################################################################################################################# -# Benchmarking SparseLinearModule - - -# adapter_config = SparseMaskConfig( -# modify_layers=modify_layers, -# sps_impl="sp_add+sp_mm", -# sps_type="regular_sparse", -# keep_ratio=keep_ratio, -# # mask_updater=mask_updater, -# reselection_steps=1, -# block_size=block_size, -# ) -# run_benchmark("SparseLinearModule (reg sp.)", adapter_config) - -# ############################################################################################################################################################ -# # Benchmarking SparseLinearModule with block sparsity - - -# adapter_config = SparseMaskConfig( -# modify_layers=modify_layers, -# sps_impl="sp_add+sp_mm", -# sps_type="block_sparse", -# keep_ratio=keep_ratio, -# reselection_steps=1, -# block_size=block_size, -# ) -# run_benchmark("SparseLinearModule (block sp.)", adapter_config) - -################################################################################################################################################################# -# Benchmarking SPiEL with regular sparsity kernel - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="spiel", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("Spiel Linear (reg. 
sp)", adapter_config) - - -################################################################################################################################################################# -# Benchmarking MaskedLinear with regular sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="masked_linear", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("MaskedLinear (reg. sp)", adapter_config) - -############################################################################################################################################################ -# Benchmarking MaskedLinear with block sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="masked_linear", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("MaskedLinear (block sp.)", adapter_config) - -################################################################################################################################################################# -# Benchmarking ScatteredSparseLinearModule - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="scattered", - sps_type="block_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("ScatteredSparseLinearModule (block sp.)", adapter_config) - -############################################################################################################################################################ -# Benchmarking ScatteredSparseLinearModule with regular sparsity - - -adapter_config = SparseMaskConfig( - modify_layers=modify_layers, - sps_impl="scattered", - sps_type="regular_sparse", - keep_ratio=keep_ratio, - reselection_steps=1, - block_size=block_size, -) -run_benchmark("ScatteredSparseLinearModule (reg sp.)", adapter_config) - -############################################################################################################################################################ -# orer table by Av. Runtime -table = table.sort_values("Av. 
Runtime") -print(table) -# write table to a csv file -table.to_csv("benchmark_results.csv") diff --git a/mttl/models/modifiers/sparsity/__init__.py b/mttl/models/modifiers/sparsity/__init__.py new file mode 100644 index 000000000..f0a817b96 --- /dev/null +++ b/mttl/models/modifiers/sparsity/__init__.py @@ -0,0 +1 @@ +from mttl.models.modifiers.sparse_mask import * diff --git a/mttl/models/modifiers/sparsity/mask_updater.py b/mttl/models/modifiers/sparsity/mask_updater.py new file mode 100644 index 000000000..0d2908ed2 --- /dev/null +++ b/mttl/models/modifiers/sparsity/mask_updater.py @@ -0,0 +1,158 @@ +import torch +from scipy.sparse import csr_matrix +from torch import nn + +from mttl.logging import logger +from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig +from mttl.models.modifiers.sparsity.sparse_linear import MaskedLinear, SparseLinear +from mttl.models.modifiers.sparsity.sparse_utils.utils import ( + get_2d_indices_from_csr_matrix, + get_top_k_sparcity, + scipy_csr_to_torch_csr, + torch_csr_to_scipy_csr, +) +from mttl.registrable import Registrable + + +class MaskUpdater(nn.Module, Registrable): + def __init__(self, config: SparseMaskConfig): + super().__init__() + self.config = config + + +@MaskUpdater.register("snip", config_cls=SparseMaskConfig) +class SNIPMaskUpdater(MaskUpdater): + """ + It is used to periodically re-calculate the sparse mask indices a la SNIP (https://arxiv.org/pdf/1810.02340). + To recalculate the mask, it uses ONE infoming batch to estimate the importance of each parameter. + + It accumulates learned weights in a dense CPU matrix. For MaskedLinear implementation this accumulation is already done in MaskedLinear class, since sparse mask is kept in dense format. + This accumulation is useful e.g. to make sure that the weights that have been learned in the past and are selected again are not reinitialized to 0. + """ + + def __init__( + self, config: SparseMaskConfig, base_weights_shape, base_weights_shape_dtype + ): + super().__init__(config) + + self.keep_ratio = config.keep_ratio + self.block_size = config.block_size + + self._n_mask_updates = 0 + self.updating_the_mask = False + + self.binary_mask = None + self._selected_indices = None + self._backward_hooks = [] + self.sparse_layer_weights, self.sparse_layer_biases = None, None + + # sparse weights for accumulation on CPU + self.accumulated_sparse_weights = torch.zeros( + base_weights_shape, device="cpu", dtype=base_weights_shape_dtype + ) + + def switch_to_mask_update_mode(self, sparse_layer: SparseLinear): + self.updating_the_mask = True + self._selected_indices = None + base_weights, base_biases, sparse_weights, sparse_biases = ( + sparse_layer.get_weights_for_mask_learning() + ) + if isinstance(sparse_layer, MaskedLinear): + # here we already keep sparse weights as dense matrix, so accumulation in SNIP is not needed + self.sparse_layer_weights = base_weights + sparse_weights + else: + assert isinstance(sparse_weights, csr_matrix) + # need to do two things: + # 1. keep track of accumulated sparse weights + # 2. 
Merge those accumulated weight deltas into the base weights and use them for importance estimation + r, c = get_2d_indices_from_csr_matrix(sparse_weights) + if len(r) > 0: + self.accumulated_sparse_weights[r, c] = torch.tensor( + sparse_weights[r, c], + dtype=self.accumulated_sparse_weights.dtype, + device="cpu", + ) + self.sparse_layer_weights = ( + base_weights + self.accumulated_sparse_weights.to(base_weights.device) + ) + + self.sparse_layer_biases = base_biases + if sparse_biases is not None: + if self.sparse_layer_biases is None: + self.sparse_layer_biases = sparse_biases.detach() + else: + self.sparse_layer_biases += sparse_biases.detach() + + self.binary_mask = torch.ones_like( + self.sparse_layer_weights, device=self.sparse_layer_weights.device + ) + self.binary_mask.requires_grad = True + + def mask_backward_hook(mask): + selected_params_dense = get_top_k_sparcity( + mask.grad, self.config.sps_type, self.keep_ratio, self.block_size + ) + selected_params = selected_params_dense.float().to_sparse_csr() # .cpu() + if self._selected_indices == None: + self._selected_indices = selected_params # .coalesce() + else: + self._selected_indices += selected_params + self._selected_indices = self._selected_indices # .coalesce() + + mask.grad = None # be efficient, throw aways the grads + return None + + hook_handle = self.binary_mask.register_post_accumulate_grad_hook( + mask_backward_hook + ) + self._backward_hooks.append(hook_handle) + + def switch_to_weights_update_mode(self, sparse_layer: SparseLinear): + self.unregister_hooks() + self.updating_the_mask = False + self.sparse_layer_weights, self.sparse_layer_biases = None, None + # update the mask of the sparse layer + # SNIP weight accumulation: we set the newly selected weights to zeros, + # but weights that have been already learned in the past are kept + if isinstance(sparse_layer, MaskedLinear): + new_weights = self.selected_indices + else: + # other sparse layers than MaskedLinear, do not accumulate weights + # so its handeled here + new_weights = self.selected_indices + new_weights = torch_csr_to_scipy_csr(new_weights) + r, c = get_2d_indices_from_csr_matrix(new_weights) + new_weights *= 0.0 + new_weights[r, c] = self.accumulated_sparse_weights[r, c].float() + new_weights = scipy_csr_to_torch_csr(new_weights) + + sparse_layer.reset_sparse_weights(new_weights) + self._selected_indices = None + self.binary_mask = None + self._n_mask_updates += 1 + + @property + def selected_indices(self) -> torch.Tensor: + if self.config.steps_in_mask_selection == 1: + return self._selected_indices + raise NotImplementedError( + "More than one step in mask selection is not supported" + ) + + def forward(self, x: torch.Tensor): + input_dtype = x.dtype + x = x.to(self.sparse_layer_weights.dtype) + bias = ( + self.sparse_layer_biases.detach().to(self.sparse_layer_weights.dtype) + if self.sparse_layer_biases is not None + else None + ) + assert self.sparse_layer_weights is not None + return torch.nn.functional.linear( + x, self.sparse_layer_weights.detach() * self.binary_mask, bias + ).to(input_dtype) + + def unregister_hooks(self): + for hook in self._backward_hooks: + hook.remove() + self._backward_hooks = [] diff --git a/mttl/models/modifiers/sparsity/sm_updater.py b/mttl/models/modifiers/sparsity/sm_updater.py new file mode 100644 index 000000000..f13bb25a3 --- /dev/null +++ b/mttl/models/modifiers/sparsity/sm_updater.py @@ -0,0 +1,196 @@ +from abc import ABC, abstractmethod +from collections import namedtuple +from dataclasses import dataclass +from 
typing import Union + +import numpy as np +import torch +from scipy.sparse import csr_matrix +from torch import nn +from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut + +from mttl.logging import logger +from mttl.models.modifiers.base import Modifier +from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig +from mttl.models.modifiers.sparsity.sparse_linear import MaskedLinear, SparseLinear +from mttl.models.modifiers.sparsity.sparse_utils.utils import ( + get_2d_indices_from_csr_matrix, + get_top_k_sparcity, + scipy_csr_to_torch_csr, + torch_csr_to_scipy_csr, +) +from mttl.registrable import Registrable + + +class MaskUpdater(nn.Module, Registrable): + def __init__(self, config: SparseMaskConfig): + super().__init__() + self.config = config + + +@MaskUpdater.register("snip", config_cls=SparseMaskConfig) +class SNIPMaskUpdater(MaskUpdater): + """ + It is used to periodically re-calculate the sparse mask indices a la SNIP (https://arxiv.org/pdf/1810.02340). + To recalculate the mask, it uses a couple of incoming mini-batches to estimate the importance of each parameter. + + It accumulates learned weights in a dense CPU matrix. + This is useful e.g. to make sure that the weights that have been learned in the past and are selected again are not reinitialized to 0. + """ + + def __init__( + self, config: SparseMaskConfig, base_weights_shape, base_weights_shape_dtype + ): + super().__init__(config) + + self.keep_ratio = config.keep_ratio + self.block_size = config.block_size + + self._steps_since_last_mask_update = int(config.skip_zeros_mask_update) + self._mask_update_steps = 0 + self._n_mask_updates = 0 + + self.updating_the_mask = False + + self.binary_mask = None + self._selected_indices = None + self._backward_hooks = [] + self.sparse_layer_weights, self.sparse_layer_biases = None, None + + # sparse weights for accumulation on CPU + self.accumulated_sparse_weights = torch.zeros( + base_weights_shape, device="cpu", dtype=base_weights_shape_dtype + ) + + def switch_to_mask_update_mode(self, sparse_layer): + self.updating_the_mask = True + self._selected_indices = None + base_weights, base_biases, sparse_weights, sparse_biases = ( + sparse_layer.get_weights_for_mask_learning() + ) + if isinstance(sparse_layer, MaskedLinear): + # here we already keep sparse weights as dense matrix, so accumulation in SNIP is not needed + self.sparse_layer_weights = base_weights + sparse_weights + else: + assert isinstance(sparse_weights, csr_matrix) + # need to do two things: + # 1. keep track of accumulated sparse weights + # 2. 
Merge those accumulated weight deltas into the base weights and use them for importance estimation + r, c = get_2d_indices_from_csr_matrix(sparse_weights) + if len(r) > 0: + self.accumulated_sparse_weights[r, c] = torch.tensor( + sparse_weights[r, c], + dtype=self.accumulated_sparse_weights.dtype, + device="cpu", + ) + self.sparse_layer_weights = ( + base_weights + self.accumulated_sparse_weights.to(base_weights.device) + ) + + self.sparse_layer_biases = base_biases + if sparse_biases is not None: + if self.sparse_layer_biases is None: + self.sparse_layer_biases = sparse_biases.detach() + else: + self.sparse_layer_biases += sparse_biases.detach() + + self.binary_mask = torch.ones_like( + self.sparse_layer_weights, device=self.sparse_layer_weights.device + ) + self.binary_mask.requires_grad = True + + def mask_backward_hook(mask): + selected_params_dense = get_top_k_sparcity( + mask.grad, self.config.sps_type, self.keep_ratio, self.block_size + ) + selected_params = selected_params_dense.float().to_sparse_csr() # .cpu() + if self._selected_indices == None: + self._selected_indices = selected_params # .coalesce() + else: + self._selected_indices += selected_params + self._selected_indices = self._selected_indices # .coalesce() + + mask.grad = None # be efficient, throw aways the grads + return None + + hook_handle = self.binary_mask.register_post_accumulate_grad_hook( + mask_backward_hook + ) + self._backward_hooks.append(hook_handle) + + def switch_to_weights_update_mode(self, sparse_layer: SparseLinear): + self.unregister_hooks() + self.updating_the_mask = False + self.sparse_layer_weights, self.sparse_layer_biases = None, None + # update the mask of the sparse layer + # SNIP weight accumulation: we set the newly selected weights to zeros, + # but weights that have been already learned in the past are kept + if isinstance(sparse_layer, MaskedLinear): + new_weights = self.selected_indices + else: + # other sparse layers than MaskedLinear, do not accumulate weights + # so its handeled here + new_weights = self.selected_indices + new_weights = torch_csr_to_scipy_csr(new_weights) + r, c = get_2d_indices_from_csr_matrix(new_weights) + new_weights *= 0.0 + new_weights[r, c] = self.accumulated_sparse_weights[r, c].float() + new_weights = scipy_csr_to_torch_csr(new_weights) + + sparse_layer.reset_sparse_weights(new_weights) + self._selected_indices = None + self.binary_mask = None + self._n_mask_updates += 1 + + @property + def selected_indices(self) -> torch.Tensor: + if self.config.steps_in_mask_selection == 1: + return self._selected_indices + raise NotImplementedError( + "More than one step in mask selection is not supported" + ) + + def prepare_mask_or_weights_learning(self, sparse_layer: SparseLinear): + """ + Currently we have two regimes that we alternate: + - mask learning: update the non-zero indices + - weight learning: update the sparse weights + + Here we figure out what regume we are in. 
+ """ + if self._time_to_update_mask(sparse_layer) and not self.updating_the_mask: + self.switch_to_mask_update_mode(sparse_layer) + self._mask_update_steps += 1 + + elif self.updating_the_mask and not self._time_to_update_sparse_weights( + sparse_layer + ): + self._mask_update_steps += 1 + + elif self.updating_the_mask and self._time_to_update_sparse_weights( + sparse_layer + ): + self.switch_to_weights_update_mode(sparse_layer) + self._mask_update_steps = 0 + self._steps_since_last_mask_update = 0 + + if not self.updating_the_mask: + self._steps_since_last_mask_update += 1 + + def forward(self, sparse_layer: SparseLinear, x: torch.Tensor): + input_dtype = x.dtype + x = x.to(self.sparse_layer_weights.dtype) + bias = ( + self.sparse_layer_biases.detach().to(self.sparse_layer_weights.dtype) + if self.sparse_layer_biases is not None + else None + ) + assert self.sparse_layer_weights is not None + return torch.nn.functional.linear( + x, self.sparse_layer_weights.detach() * self.binary_mask, bias + ).to(input_dtype) + + def unregister_hooks(self): + for hook in self._backward_hooks: + hook.remove() + self._backward_hooks = [] diff --git a/mttl/models/modifiers/sparse_utils/sparse_linear.py b/mttl/models/modifiers/sparsity/sparse_linear.py similarity index 93% rename from mttl/models/modifiers/sparse_utils/sparse_linear.py rename to mttl/models/modifiers/sparsity/sparse_linear.py index 4ce38854f..6c74d8d44 100644 --- a/mttl/models/modifiers/sparse_utils/sparse_linear.py +++ b/mttl/models/modifiers/sparsity/sparse_linear.py @@ -11,10 +11,9 @@ from mttl.logging import logger from mttl.models.modifiers.base import Modifier, ModifierConfig -from mttl.models.modifiers.sparse_utils.utils import ( +from mttl.models.modifiers.sparsity.sparse_utils.utils import ( BlcokSparseLinearFunction_SP_ADD, BlcokSparseLinearFunction_SP_SCATTER, - LinearWithSparseDelta, SparseLinearFunction_SP_ADD, _scatter_add_flattened, get_2d_indices_from_csr_matrix, @@ -361,6 +360,12 @@ def scipy_representation(self): return csr_matrix((data, (row_idx, col_idx)), shape=self.base_weight.shape) +############# +# MaskedLinear keeps sparse weights in the dense format. THis has the advantage that we do not neet to fumble with the optimizer. +# Class below try implementing sparse layer in a memory efficient way, similar to to SpIEL (https://arxiv.org/pdf/2401.16405), which uses essentially uses ScatteredSparseLinearModule. +# Using below classes may require additional tricks like in the SpIEL paper. + + class SparseLinearModule(SparseWeights, SparseLinear): """ Implements a sparse linear layer with sparse weights and sparse backprop. @@ -588,53 +593,3 @@ def reset_sparse_weights(self, mask: torch.Tensor): dtype=torch.int64, device=self.base_weight.device, ) - - -class SpieLSparseLinearModule(SparseLinearModule): - """ - This implements the SpIEL kernel: https://arxiv.org/pdf/2401.16405 - """ - - def __init__( - self, - weight, - bias, - config: SparseLinearConfig, - parent_name=None, - mask: torch.Tensor = None, - ): - super().__init__( - weight, - bias, - config, - parent_name, - sparse_func=LinearWithSparseDelta, - ) - indices = torch.tensor( - np.array(self.oneD_indices), - dtype=torch.int64, - device=self.base_weight.device, - ) - self.register_buffer("idxs", indices) - - @property - def oneD_indices(self): - """ - Returns a simple 1d representation of the sparse weights instead of the CSR format. 
- """ - twoD_indices = self.twoD_indices - return twoD_indices[0] * self.shape[1] + twoD_indices[1] - - def forward(self, input): - bias = self.base_bias - if bias and self.sparse_bias: - bias = self.base_bias + self.sparse_bias - return self.sparse_func.apply( - input, - self.base_weight, - self.sparse_weights, - self.idxs, - bias, - None, - self.base_weight.dtype, - ) diff --git a/mttl/models/modifiers/sparsity/sparse_mask.py b/mttl/models/modifiers/sparsity/sparse_mask.py new file mode 100644 index 000000000..a94505583 --- /dev/null +++ b/mttl/models/modifiers/sparsity/sparse_mask.py @@ -0,0 +1,122 @@ +from abc import ABC, abstractmethod +from collections import namedtuple +from dataclasses import dataclass +from typing import Union + +import numpy as np +import torch +from scipy.sparse import csr_matrix +from torch import nn +from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut + +from mttl.logging import logger +from mttl.models.modifiers.base import Modifier, ModifierConfig +from mttl.models.modifiers.sparse_mask_config import SparseMaskConfig +from mttl.models.modifiers.sparse_utils.sparse_linear import ( + MaskedLinear, + ScatteredSparseLinearModule, + SparseLinear, + SparseLinearConfig, +) +from mttl.models.modifiers.sparse_utils.utils import ( + get_2d_indices_from_csr_matrix, + get_top_k_sparcity, + scipy_csr_to_torch_csr, + torch_csr_to_scipy_csr, +) +from mttl.models.modifiers.sparsity.mask_updater import MaskUpdater + + +class SparseMaskAdapter(Modifier): + def __init__( + self, + config: SparseMaskConfig, + layer: nn.Module, + **kwargs, + ): + self.name = kwargs.get("layer_name", None) + super().__init__() + self.config = config + + self.dense_layer_weight = layer.weight + self.dense_layer_bias = layer.bias + self.sps_type = config.sps_type + self.sparse_layer: SparseLinear = None + self.mask_updater: MaskUpdater = None + if not self.config.mask_updater is None: + self.mask_updater: MaskUpdater = MaskUpdater.get_class_by_name( + config.mask_updater, + )( + self.config, + base_weights_shape=self.dense_layer_weight.shape, + base_weights_shape_dtype=self.dense_layer_weight.dtype, + ) + self.maks_update_mode = False + + def forward(self, input): + if self.maks_update_mode and self.training: + return self.mask_updater(self.sparse_layer, input) + return self.sparse_layer(input) + + def prepare_for_mask_update(self): + if self.mask_updater is not None: + self.mask_updater.switch_to_mask_update_mode(self.sparse_layer) + self.maks_update_mode = True + + def prepare_for_weights_update(self): + if self.mask_updater is not None: + self.mask_updater.switch_to_weights_update_mode(self.sparse_layer) + self.maks_update_mode = False + + +@dataclass +class ScatteredConfig(SparseMaskConfig): + pass + + +@Modifier.register("scattered_sparse_adapter", config_cls=ScatteredConfig) +class ScatteredSparseAdapter(SparseMaskAdapter): + """ + Sparse adapter that only keeps non-zero weights around as parameters. + """ + + def __init__( + self, + config: ScatteredConfig, + layer: nn.Module, + **kwargs, + ): + super().__init__(config, layer, **kwargs) + self.sparse_layer: SparseLinear = ScatteredSparseLinearModule( + self.dense_layer_weight, + self.dense_layer_bias, + self.config, + parent_name=self.name, + ) + + +@dataclass +class MLSConfig(SparseMaskConfig): + init_all_ones: bool = False + + +@Modifier.register("mls_sparse_adapter", config_cls=MLSConfig) +class MaskedLinearSparseAdapter(SparseMaskAdapter): + """ + Sparse adapter that keeps the sparse weights as dense matrix. 
+ """ + + def __init__( + self, + config: MLSConfig, + layer: nn.Module, + **kwargs, + ): + super().__init__(config, layer, **kwargs) + self.sparse_layer: SparseLinear = MaskedLinear( + self.dense_layer_weight, + self.dense_layer_bias, + self.config, + parent_name=self.name, + init_all_ones=config.init_all_ones, + ) diff --git a/mttl/models/modifiers/sparsity/sparse_utils/bsr_ddsloop_benchmark.py b/mttl/models/modifiers/sparsity/sparse_utils/bsr_ddsloop_benchmark.py new file mode 100644 index 000000000..eb317f381 --- /dev/null +++ b/mttl/models/modifiers/sparsity/sparse_utils/bsr_ddsloop_benchmark.py @@ -0,0 +1,668 @@ +import logging +import re +import time +from typing import List + +import numpy as np +import pandas as pd +import stk.ops +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton as tn +from pytorch_lightning import seed_everything +from spops import csr_add, spmm +from triton.ops.blocksparse import matmul + +from mttl.logging import logger +from mttl.models.modifiers import modify_transformer +from mttl.models.modifiers.base import Modifier +from mttl.models.modifiers.lora import LoRA, LoRAConfig, SkilledLoRA, SkilledLoRAConfig +from mttl.models.modifiers.sparse_mask import ( + MaskedLinear, + ScatteredSparseLinearModule, + SparseLinearModule, + SparseMaskAdapter, + SparseMaskConfig, +) +from mttl.models.utils import model_loader_helper, transfer_batch_to_device + +device = "cuda" +logger.setLevel(logging.ERROR) +block_size = 16 # 128 # 16 +n_blocks = 1024 # 16 # 1024 + +in_d = 2048 +out_d = 8192 +dtype = torch.bfloat16 + +# input sizes and batch sizes for testing +max_seq_len = 1024 +bs = 5 + + +layer = nn.Linear(in_d, out_d).to(device) +layer.weight.requires_grad_(False) +layer.bias.requires_grad_(False) +K = 10 + + +def calculate_lora_parameters(input_dim, output_dim, rank): + return input_dim * rank + output_dim * rank + + +def find_hyperpaams(): + modules = {"linear": layer} + modified_modules = {} + keep_ratios = [] + lora_ranks = [] + + for name, module in modules.items(): + keep_ratio = ( + n_blocks * (block_size**2) / (module.in_features * module.out_features) + ) + tot_sparse_params = module.in_features * module.out_features * keep_ratio + lora_rank = 1 + for rank in range(1, module.in_features): + lora_params = calculate_lora_parameters( + module.in_features, module.out_features, rank + ) + if lora_params <= tot_sparse_params: + lora_rank = rank + else: + break + modified_modules[name] = { + "module": module, + "keep_ratio": keep_ratio, + "lora_rank": lora_rank, + } + keep_ratios.append(keep_ratio) + lora_ranks.append(lora_rank) + return np.mean(keep_ratios), int(np.mean(lora_ranks)) + + +keep_ratio, lora_rank = find_hyperpaams() +print( + f"Keep ratio: {keep_ratio}, LoRA rank: {lora_rank}, Lora params: {calculate_lora_parameters(in_d, out_d, lora_rank)}, Sparse params: {in_d * out_d * keep_ratio}" +) +x = torch.randn(bs, max_seq_len, in_d, dtype=dtype, device=device) + + +def create_adapter_set(adapter_config, layer, K) -> List[Modifier]: + if isinstance(adapter_config, SparseMaskConfig): + layer = nn.Linear(out_d, in_d) # TODO: implement transpose in SparseWeights + module = [SparseMaskAdapter(adapter_config, layer) for _ in range(K)] + elif isinstance(adapter_config, LoRAConfig): + module = [LoRA(adapter_config, layer) for _ in range(K)] + return module + + +@torch.autocast(device_type="cuda", dtype=dtype) +def lora_merge(lora_a, lora_b, x, W_base, W_merge): + + # merge into 1 loa + A = torch.einsum("ble,edr->bldr", (W_merge, 
lora_a)) + B = torch.einsum("ble,erd->blrd", (W_merge, lora_b)) + # lora forward + partial_out = torch.einsum("bld,bldr->blr", (x, A)) + adapter_out = torch.einsum("blr,blrd->bld", (partial_out, B)) + dense_out = x @ W_base + return adapter_out + dense_out + + +@torch.autocast(device_type="cuda", dtype=dtype) +def sparse_merge_and_forward(sparse_weights, x, W_base, W_merge): + """ + Perform the merging of sparse adapters and compute the forward pass. This uses torch dds mm. + + Parameters: + - sparse_weights: List[torch.Tensor], each of shape [input_dim, output_dim] in CSR format. + - x: torch.Tensor, input of shape [bs, max_seq_len, input_dim]. + - W_base: torch.Tensor, base model weights of shape [input_dim, output_dim]. + - W_merge: torch.Tensor, merging weights of shape [bs, max_seq_len, K]. + + Returns: + - y: torch.Tensor, output of shape [bs, max_seq_len, output_dim]. + """ + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_out = torch.zeros_like(base_out) + + # Iterate over each adapter + for k in range(K): + S_k = sparse_weights[k] # Sparse matrix of shape [input_dim, output_dim] + + # Compute the output for this adapter + output_k = ( + x_flat @ S_k + ) # Shape: [bs * max_seq_len, output_dim] <- this is dds mm + + # Get the merging weights for this adapter + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + + # Scale and accumulate the adapter output + adapter_out += output_k * W_k + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def blck_sparse_merge_and_forward(sparse_weights, x, W_base, W_merge): + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_out = torch.zeros_like(base_out) + + # Iterate over each adapter + for k in range(K): + S_k = sparse_weights[k] # Sparse matrix of shape [input_dim, output_dim] + + # Compute the output for this adapter + output_k = F.linear(x_flat, S_k) # Shape: [bs * max_seq_len, output_dim] + + # Get the merging weights for this adapter + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + + # Scale and accumulate the adapter output + adapter_out += output_k * W_k + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def sparse_merge_and_forward_with_SpMM(sparse_weights, x, W_base, W_merge): + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_out = torch.zeros_like(base_out) + + # Iterate over each adapter 
+ for k in range(K): + S_k = sparse_weights[k] # Sparse matrix of shape [input_dim, output_dim] + + # Compute the output for this adapter + output_k = spmm( + S_k.sparse_weights, + S_k.row_offs, + S_k.row_idx, + S_k.col_idx, + x_flat.T.contiguous(), + S_k.shape[0], + backend="sputnik", + ) # Shape: [bs * max_seq_len, output_dim] + + # Get the merging weights for this adapter + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + + # Scale and accumulate the adapter output + adapter_out += output_k.T * W_k + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def sparse_merge_and_forward_with_spadd(sparse_weights, x, W_base, W_merge): + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_w = torch.zeros(sparse_weights[0].shape).to(device) + + # Iterate over each adapter + for k in range(K): + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + S_k = sparse_weights[k] + + # Compute the output for this adapter + adapter_w = csr_add( + S_k.sparse_weights * W_k, S_k.row_offs, S_k.row_idx, S_k.col_idx, adapter_w + ) + + adapter_out = x_flat @ adapter_w + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def sparse_merge_and_forward_vectorized(sparse_weights, x, W_base, W_merge): + """ + Perform the merging of sparse adapters and compute the forward pass. + + Parameters: + - sparse_weights: List[torch.Tensor], each of shape [input_dim, output_dim] in CSR format. + - x: torch.Tensor, input of shape [bs, max_seq_len, input_dim]. + - W_base: torch.Tensor, base model weights of shape [input_dim, output_dim]. + - W_merge: torch.Tensor, merging weights of shape [bs, max_seq_len, K]. + + Returns: + - y: torch.Tensor, output of shape [bs, max_seq_len, output_dim]. 
+ """ + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Stack and expand sparse weights + # Convert sparse weights to dense (if memory allows) + sparse_weights_dense = torch.stack( + [S.to_dense() for S in sparse_weights], dim=0 + ) # [K, input_dim, output_dim] + + # Compute adapter outputs + # [bs*max_seq_len, K, output_dim] + adapter_out = torch.einsum("bi,kio->bko", x_flat, sparse_weights_dense) + W_merge_flat = W_merge.reshape(bs * max_seq_len, K, 1) # [bs*max_seq_len, K, 1] + adapter_out = (adapter_out * W_merge_flat).sum( + dim=1 + ) # [bs*max_seq_len, output_dim] + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def blk_sparse_merge_and_forward_triton( + block_sparse_ops, block_sparse_weights, x, W_base, W_merge +): + """ + Perform the merging of sparse adapters and compute the forward pass. This uses triton dds kernel with precomputed layour (see prepare_triton_bs_op). + + Parameters: + - block_sparse_ops: List[triton.ops.blocksparse.matmul], each of shape [input_dim, output_dim] in CSR format. + - block_sparse_weights: List[torch.Tensor], each of shape [input_dim, output_dim] in BSR format (these are only non-zero blocks). + - x: torch.Tensor, input of shape [bs, max_seq_len, input_dim]. + - W_base: torch.Tensor, base model weights of shape [input_dim, output_dim]. + - W_merge: torch.Tensor, merging weights of shape [bs, max_seq_len, K]. + + Returns: + - y: torch.Tensor, output of shape [bs, max_seq_len, output_dim]. 
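+
+    Note: triton's blocksparse matmul operates on 4D (Z, H, M, K) tensors, which is
+    why x_flat is unsqueezed to [1, 1, bs * max_seq_len, input_dim] before each op
+    call and the result is squeezed back.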
+ """ + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_out = torch.zeros_like(base_out) + + # Iterate over each adapter + for k in range(K): + S_k = block_sparse_weights[k] # Sparse matrix of shape [input_dim, output_dim] + _x_flat = x_flat.unsqueeze(0).unsqueeze(0).contiguous() + # Compute the output for this adapter + output_k = block_sparse_ops[k]( + _x_flat, S_k + ).squeeze() # Shape: [bs * max_seq_len, output_dim] + + # Get the merging weights for this adapter + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + + # Scale and accumulate the adapter output + adapter_out += output_k * W_k + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +@torch.autocast(device_type="cuda", dtype=dtype) +def blk_sparse_merge_and_forward_stk(block_sparse_weights, x, W_base, W_merge): + bs, max_seq_len, input_dim = x.shape + output_dim = W_base.shape[1] + K = W_merge.shape[2] + + device = x.device + dtype = x.dtype + + # Flatten x for efficient computation + x_flat = x.reshape(bs * max_seq_len, input_dim) + + # Compute base output + base_out = x_flat @ W_base + + # Initialize adapter output + adapter_out = torch.zeros_like(base_out) + + # Iterate over each adapter + for k in range(K): + S_k = block_sparse_weights[k] # Sparse matrix of shape [input_dim, output_dim] + # Compute the output for this adapter + output_k = stk.ops.dds(x_flat, S_k) # Shape: [bs * max_seq_len, output_dim] + + # Get the merging weights for this adapter + W_k = W_merge[:, :, k].reshape( + bs * max_seq_len, 1 + ) # Shape: [bs * max_seq_len, 1] + + # Scale and accumulate the adapter output + adapter_out += output_k * W_k + + # Sum base and adapter outputs + y_flat = base_out + adapter_out + + # Reshape back to [bs, max_seq_len, output_dim] + y = y_flat.reshape(bs, max_seq_len, output_dim) + + return y + + +sparse_merge_and_forward_compiled = torch.compile(sparse_merge_and_forward) +lora_merge_compiled = torch.compile(lora_merge) + +# adapter_config_sm = SparseMaskConfig( +# sps_impl="scattered", +# sps_type="regular_sparse", +# keep_ratio=keep_ratio, +# reselection_steps=1, +# block_size=block_size, +# ) + + +adapter_config_lora = LoRAConfig(modify_layers="", lora_rank=lora_rank) +adapter_config_bs = SparseMaskConfig( + sps_impl="scattered", + sps_type="block_sparse", + keep_ratio=keep_ratio, + reselection_steps=1, + block_size=block_size, +) + + +def bsr_to_binary_layout(bsr_matrix, block_size): + # Get the shape of the BSR matrix + M, K = bsr_matrix.shape + + # Number of blocks along rows and columns + num_block_rows = M // block_size + num_block_cols = K // block_size + + # Initialize the binary layout matrix with zeros + binary_layout = torch.zeros((num_block_rows, num_block_cols), dtype=int) + + # Get BSR matrix data + block_row_indices = bsr_matrix.col_indices() + block_row_pointers = bsr_matrix.crow_indices() + + # Iterate over the block rows + for block_row in range(num_block_rows): + # Iterate over the non-zero blocks in the current block row + for idx in range( + block_row_pointers[block_row], block_row_pointers[block_row + 1] + ): + block_col = block_row_indices[idx] + # Mark the 
block as non-zero + binary_layout[block_row, block_col] = 1 + + return binary_layout + + +def prepare_triton_bs_op(W, op_mode): + Z, H = 1, 1 + AT = False + BT = False + + layout = bsr_to_binary_layout(W, block_size).unsqueeze(0) + # creat inputs + op = matmul(layout, block_size, op_mode, trans_a=AT, trans_b=BT, device="cuda") + return op + + +@tn.testing.perf_report( + tn.testing.Benchmark( + x_names=["K"], # Argument names to use as an x-axis for the plot. + x_vals=[2, 3, 4, 10, 64, 128], # Different possible values for `x_name`. + x_log=False, # x axis is logarithmic. + line_arg="provider", # Argument name whose value corresponds to a different line in the plot. + line_vals=[ + "stk", + "triton_blck_sparse", + "lora", + "torch_sparse", + "torch_block_sparse", + ], # "lora_compiled", "torch_sparse_compiled"], # Possible values for `line_arg`. + line_names=[ + "stk", + "triton_blck_sparse", + "lora", + "torch_sparse", + "torch_block_sparse", + ], # "lora_compiled", "torch_sparse_compiled"], # Label name for the lines. + styles=[ + ("blue", "-"), + ("green", "-"), + ("orange", "-"), + ("red", "-"), + ("purple", "-"), + ("black", "-"), + ("brown", "-"), + ], # Line color and style. + ylabel="ms", #'GB/s', # Label name for the y-axis. + xlabel="K", + plot_name="matmul-performance", # Name for the plot. Used also as a file name for saving the plot. + args={"bs": bs, "max_seq_len": max_seq_len, "in_d": in_d, "d_out": out_d}, + ) +) +def benchmark(K, bs, max_seq_len, in_d, d_out, provider): + W_mege = torch.randn(bs, max_seq_len, K, dtype=dtype, device=device) + loras = create_adapter_set(adapter_config_lora, layer, K) + sparse_modules = create_adapter_set(adapter_config_bs, layer, K) + W_mege = W_mege.to(dtype=loras[0].lora_a.dtype) + + lora_a = torch.stack([lora.lora_a for lora in loras], dim=0) + lora_b = torch.stack([lora.lora_b for lora in loras], dim=0) + sparse_weights: List[torch.Tensor] = [ + sparse_module.sparse_layer.to_dense().to_sparse_csr().to(device) + for sparse_module in sparse_modules + ] + sparse_weights_spops = [ + sparse_module.sparse_layer.to(device) for sparse_module in sparse_modules + ] + + print("Testing provider:", provider, "K:", K) + quantiles = [0.5, 0.2, 0.8] + if provider == "lora": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: lora_merge(lora_a, lora_b, x, layer.weight.T, W_mege), + quantiles=quantiles, + ) + elif provider == "lora_compiled": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: lora_merge_compiled(lora_a, lora_b, x, layer.weight.T, W_mege), + quantiles=quantiles, + ) + elif provider == "torch_sparse": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: sparse_merge_and_forward(sparse_weights, x, layer.weight.T, W_mege), + quantiles=quantiles, + ) + elif provider == "torch_block_sparse": + block_sparse_weights: List[torch.Tensor] = [ + sparse_module.sparse_layer.to_dense().T.to_sparse_bsr(block_size).to(device) + for sparse_module in sparse_modules + ] + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: blck_sparse_merge_and_forward( + block_sparse_weights, x, layer.weight.T, W_mege + ), + quantiles=quantiles, + ) + elif provider == "torch_sparse_compiled": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: sparse_merge_and_forward_compiled( + sparse_weights, x, layer.weight.T, W_mege + ), + quantiles=quantiles, + ) + elif provider == "sparse_vectorized": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: sparse_merge_and_forward_vectorized( + sparse_weights, x, layer.weight.T, W_mege + ), + quantiles=quantiles, + ) + elif 
provider == "sparse_spadd": + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: sparse_merge_and_forward_with_spadd( + sparse_weights_spops, x, layer.weight.T, W_mege + ), + quantiles=quantiles, + ) + elif provider == "triton_blck_sparse": + block_sparse_weights: List[torch.Tensor] = [ + sparse_module.sparse_layer.to_dense().to_sparse_bsr(block_size).to(device) + for sparse_module in sparse_modules + ] + # create a list of ops with precomputed layouts for the BSR matrices + block_sparse_ops = [ + prepare_triton_bs_op(sparse_w, "dds") for sparse_w in block_sparse_weights + ] + # block_sparse_weights_as_dense = [ + # sparse_w.to_dense() + # .to(dtype) + # .reshape(-1, block_size, block_size) + # .unsqueeze(0) + # .contiguous() + # for sparse_w in block_sparse_weights + # ] + block_sparse_weights = [ + sparse_w.values().to(dtype).unsqueeze(0).contiguous() + for sparse_w in block_sparse_weights + ] + + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: blk_sparse_merge_and_forward_triton( + block_sparse_ops, + block_sparse_weights, + x, + layer.weight.T, + W_mege, + ), + quantiles=quantiles, + ) + + elif provider == "stk": + # only supports block_size = 128 and float16 + if block_size != 128: + ms, min_ms, max_ms = 0, 0, 0 + else: + block_sparse_weights = [] + for sparse_module in sparse_modules: + W = sparse_module.sparse_layer.to_dense().to(device).to(torch.float16) + W_stk = stk.ops.to_sparse(W, blocking=block_size) + W_stk.validate() + block_sparse_weights.append(W_stk) + ms, min_ms, max_ms = tn.testing.do_bench( + lambda: blk_sparse_merge_and_forward_stk( + block_sparse_weights, x, layer.weight.T, W_mege + ), + quantiles=quantiles, + ) + + # gbps = lambda ms: 2 * s * h * o * 2 * 1e-9 / (ms * 1e-3) + # return gbps(ms), gbps(max_ms), gbps(min_ms) + return ms, max_ms, min_ms + + +benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/mttl/models/modifiers/sparsity/sparse_utils/bsr_moe_benchmark.py b/mttl/models/modifiers/sparsity/sparse_utils/bsr_moe_benchmark.py new file mode 100644 index 000000000..579cbe532 --- /dev/null +++ b/mttl/models/modifiers/sparsity/sparse_utils/bsr_moe_benchmark.py @@ -0,0 +1,248 @@ +import logging +import re +import time +from typing import List + +import numpy as np +import pandas as pd +import stk.ops +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton as tn +from pytorch_lightning import seed_everything +from spops import csr_add, spmm +from stk.matrix import Matrix +from triton.ops.blocksparse import matmul + +from mttl.logging import logger +from mttl.models.modifiers import modify_transformer +from mttl.models.modifiers.base import Modifier +from mttl.models.modifiers.lora import LoRA, LoRAConfig, SkilledLoRA, SkilledLoRAConfig +from mttl.models.modifiers.sparse_mask import ( + MaskedLinear, + ScatteredSparseLinearModule, + SparseLinearModule, + SparseMaskAdapter, + SparseMaskConfig, +) +from mttl.models.modifiers.sparsity.sparse_utils.utils import ( + padded_gather, + padded_scatter, +) +from mttl.models.modifiers.sparsity.spb_moe import _matrix_ops, ops +from mttl.models.utils import model_loader_helper, transfer_batch_to_device + +device = "cuda" +logger.setLevel(logging.ERROR) +block_size = 16 # 128 # 16 +n_blocks = 1024 # 16 # 1024 +in_d = 2048 +out_d = 8192 +dtype = torch.bfloat16 +max_seq_len = 1024 +bs = 2 +layer = nn.Linear(in_d, out_d).to(device) +layer.weight.requires_grad_(False) +layer.bias.requires_grad_(False) +K = 100 +top_k = 2 + + +def calculate_lora_parameters(input_dim, 
output_dim, rank): + return input_dim * rank + output_dim * rank + + +def find_hyperpaams(): + modules = {"linear": layer} + modified_modules = {} + keep_ratios = [] + lora_ranks = [] + + for name, module in modules.items(): + keep_ratio = ( + n_blocks * (block_size**2) / (module.in_features * module.out_features) + ) + tot_sparse_params = module.in_features * module.out_features * keep_ratio + lora_rank = 1 + for rank in range(1, module.in_features): + lora_params = calculate_lora_parameters( + module.in_features, module.out_features, rank + ) + if lora_params <= tot_sparse_params: + lora_rank = rank + else: + break + modified_modules[name] = { + "module": module, + "keep_ratio": keep_ratio, + "lora_rank": lora_rank, + } + keep_ratios.append(keep_ratio) + lora_ranks.append(lora_rank) + return np.mean(keep_ratios), int(np.mean(lora_ranks)) + + +keep_ratio, lora_rank = find_hyperpaams() +print( + f"Keep ratio: {keep_ratio}, LoRA rank: {lora_rank}, Lora params: {calculate_lora_parameters(in_d, out_d, lora_rank)}, Sparse params: {in_d * out_d * keep_ratio}" +) +x = torch.randn(bs, max_seq_len, in_d, dtype=dtype, device=device).contiguous() + + +def create_adapter_set(adapter_config, layer, K) -> List[Modifier]: + if isinstance(adapter_config, SparseMaskConfig): + layer = nn.Linear(out_d, in_d) # TODO: implement transpose in SparseWeights + module = [SparseMaskAdapter(adapter_config, layer) for _ in range(K)] + elif isinstance(adapter_config, LoRAConfig): + module = [LoRA(adapter_config, layer) for _ in range(K)] + return module + + +def sparsemodules_to_stkmatrix_list(sparse_modules): + sparse_weights = [] + for sparse_module in sparse_modules: + mtx = stk.ops.to_sparse( + sparse_module.sparse_layer.to_dense().type(dtype), blocking=block_size + ) + # mtx.validate() + sparse_weights.append(mtx) + return sparse_weights + + +@torch.autocast(device_type="cuda", dtype=dtype) +def lora_merge(lora_a, lora_b, x, W_base, W_merge): + + # merge into 1 loa + A = torch.einsum("ble,edr->bldr", (W_merge, lora_a)) + B = torch.einsum("ble,erd->blrd", (W_merge, lora_b)) + # lora forward + partial_out = torch.einsum("bld,bldr->blr", (x, A)) + adapter_out = torch.einsum("blr,blrd->bld", (partial_out, B)) + dense_out = x @ W_base + return adapter_out + dense_out + + +def create_block_diagonal_matrix(bs_m, bs_n, n_blocks): + assert bs_m >= block_size + assert bs_n >= block_size + factor = (bs_m * bs_n) // (block_size**2) + + M = bs_m * n_blocks + N = bs_n * n_blocks + + Mb = M // block_size + Nb = N // block_size + + nb_m_pb = bs_m // block_size + nb_n_pb = bs_n // block_size + + col_indices_1blk = torch.arange(nb_n_pb, device=device, dtype=torch.int32).repeat( + nb_m_pb + ) + row_indices_1blk = torch.arange( + nb_m_pb, device=device, dtype=torch.int32 + ).repeat_interleave(nb_n_pb) + offsets = torch.arange(0, Mb * nb_n_pb + nb_n_pb, nb_n_pb, device=device) + + col_idx = torch.cat([col_indices_1blk + i * nb_n_pb for i in range(n_blocks)]) + row_idx = torch.cat([row_indices_1blk + i * nb_m_pb for i in range(n_blocks)]) + data = torch.empty((Mb * Nb, block_size, block_size), device=device) + + return Matrix((M, N), data, row_idx, col_idx, offsets) + + +adapter_config_lora = LoRAConfig(modify_layers="", lora_rank=lora_rank) +adapter_config_bs = SparseMaskConfig( + sps_impl="scattered", + sps_type="block_sparse", + keep_ratio=keep_ratio, + reselection_steps=1, + block_size=block_size, +) + +# FOWARD PASS through MoE +W_mege = torch.randn(bs, max_seq_len, K, dtype=dtype, device=device) +loras = 
create_adapter_set(adapter_config_lora, layer, K) +sparse_modules = create_adapter_set(adapter_config_bs, layer, K) +sparse_mtxs = sparsemodules_to_stkmatrix_list(sparse_modules) +adaptersMatrix: Matrix = _matrix_ops.merge_adapters(sparse_mtxs).to(device) + +W_mege = W_mege.to(dtype=loras[0].lora_a.dtype) +top_k_indices = torch.topk(torch.abs(W_mege), top_k, dim=-1).indices +( + x, + num_tokens_per_expert, + sort_order, + indices_expert_padded, + positions_in_expert_padded, + padding_mask, +) = padded_gather(x, top_k_indices, K) +layout = _matrix_ops.create_ada_layout(adaptersMatrix).to(device) + +out_blck_size = x.shape[1] +x = x.reshape(-1, in_d).contiguous() +out_topology = create_block_diagonal_matrix(out_blck_size, out_d, K) +W_base = layer.weight.T.to(dtype=dtype) +output = ops.sdd_adamerge(x, W_base, out_topology, adaptersMatrix, layout) +print(output.shape) +# create output topoly + + +# @tn.testing.perf_report( +# tn.testing.Benchmark( +# x_names=["K"], # Argument names to use as an x-axis for the plot. +# x_vals=[2, 3, 4, 10, 64, 128], # Different possible values for `x_name`. +# x_log=False, # x axis is logarithmic. +# line_arg="provider", # Argument name whose value corresponds to a different line in the plot. +# line_vals=[ +# "lora", +# ], # "lora_compiled", "torch_sparse_compiled"], # Possible values for `line_arg`. +# line_names=[ +# "lora", +# ], # "lora_compiled", "torch_sparse_compiled"], # Label name for the lines. +# styles=[ +# ("blue", "-"), +# ("green", "-"), +# ("orange", "-"), +# ("red", "-"), +# ("purple", "-"), +# ("black", "-"), +# ("brown", "-"), +# ], # Line color and style. +# ylabel="ms", #'GB/s', # Label name for the y-axis. +# xlabel="K", +# plot_name="matmul-performance", # Name for the plot. Used also as a file name for saving the plot. 
+# args={"bs": bs, "max_seq_len": max_seq_len, "in_d": in_d, "d_out": out_d}, +# ) +# ) +# def benchmark(K, bs, max_seq_len, in_d, d_out, provider): +# W_mege = torch.randn(bs, max_seq_len, K, dtype=dtype, device=device) +# loras = create_adapter_set(adapter_config_lora, layer, K) +# sparse_modules = create_adapter_set(adapter_config_bs, layer, K) +# W_mege = W_mege.to(dtype=loras[0].lora_a.dtype) + +# lora_a = torch.stack([lora.lora_a for lora in loras], dim=0) +# lora_b = torch.stack([lora.lora_b for lora in loras], dim=0) +# sparse_weights: List[torch.Tensor] = [ +# sparse_module.sparse_layer.to_dense().to_sparse_csr().to(device) +# for sparse_module in sparse_modules +# ] +# sparse_weights_spops = [ +# sparse_module.sparse_layer.to(device) for sparse_module in sparse_modules +# ] + +# print("Testing provider:", provider, "K:", K) +# quantiles = [0.5, 0.2, 0.8] +# if provider == "lora": +# ms, min_ms, max_ms = tn.testing.do_bench( +# lambda: lora_merge(lora_a, lora_b, x, layer.weight.T, W_mege), +# quantiles=quantiles, +# ) + +# # gbps = lambda ms: 2 * s * h * o * 2 * 1e-9 / (ms * 1e-3) +# # return gbps(ms), gbps(max_ms), gbps(min_ms) +# return ms, max_ms, min_ms + + +# benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/mttl/models/modifiers/sparse_utils/csr_add_vs_scatter_add.py b/mttl/models/modifiers/sparsity/sparse_utils/csr_add_vs_scatter_add.py similarity index 100% rename from mttl/models/modifiers/sparse_utils/csr_add_vs_scatter_add.py rename to mttl/models/modifiers/sparsity/sparse_utils/csr_add_vs_scatter_add.py diff --git a/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_utils.py b/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_utils.py new file mode 100644 index 000000000..9f7b7e9f4 --- /dev/null +++ b/mttl/models/modifiers/sparsity/sparse_utils/stk_matrix_utils.py @@ -0,0 +1,119 @@ +from typing import List + +import numpy as np +import stk.ops +import torch +from stk.matrix import Matrix + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return ( + dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True), + ) + + +def _merge_adapters(adapters: List[Matrix]) -> Matrix: + """ + Merges a list of adapters into a single adapter along the second dimention. + Also changes the block size by padding blocks iwht 0s if necessary. 
+ + """ + col_indices_list = [adap.column_indices.to(torch.int32) for adap in adapters] + # row_indices_list = [adap.row_indices for adap in adapters] + offsets_list = [adap.offsets for adap in adapters] + data_list = [adap.data for adap in adapters] + + num_rows = [offsets.numel() - 1 for offsets in offsets_list] + assert all( + num_rows[0] == num_rows[i] for i in range(1, len(num_rows)) + ), "All adapters must have the same number of rows" + + block_size = adapters[0].blocking + + K, N = adapters[0].size() + col_offset = N // block_size # assuming all have same number of cols + n_adaps = len(adapters) + + adjusted_col_indices = [] + for e, col_idx in enumerate(col_indices_list): + adjusted_col_indices.append(col_idx + e * col_offset) + + merged_col_indices = torch.cat(adjusted_col_indices) + row_indices = torch.cat( + [adap.row_indices.to(torch.int32) for adap in adapters], dim=0 + ) + data = torch.cat([adap.data for adap in adapters], dim=0) + + indices = torch.stack([row_indices, merged_col_indices], dim=1) + + if indices.is_cuda: + indices = indices.cpu() + + # Convert to NumPy + np_tensor = indices.numpy() + # Perform lexsort: sort by second key first, then first key + sorted_indices = np.lexsort((np_tensor[:, 1], np_tensor[:, 0])) + + data = data[sorted_indices].contiguous() + row_indices = row_indices[sorted_indices].contiguous() + col_indices = merged_col_indices[sorted_indices].contiguous() + + # recalculate offsets + num_rows = max(num_rows) + offsets = torch.zeros(num_rows + 1, dtype=torch.int32, device=row_indices.device) + counts_per_row = torch.bincount(row_indices, minlength=num_rows) + offsets[1:] = torch.cumsum(counts_per_row, dim=0) + offsets = offsets.contiguous() + + return Matrix((K, n_adaps * N), data, row_indices, col_indices, offsets) + + +def change_block_size(M: Matrix, new_blk_size) -> Matrix: + raise NotImplementedError("change_block_size is not implemented yet") + return + + +def merge_adapters(adapters: List[Matrix], blk_size=None) -> Matrix: + """ + Merges a list of adapters into a single adapter along the second dimention. + Also changes the block size by padding blocks iwht 0s if necessary. + + """ + + out = _merge_adapters( + adapters + ) # merges the adapters into a single Matrix() without changing the block size + if blk_size is not None: + out = change_block_size(out, blk_size) + return out + + +def create_ada_layout(matix: Matrix): + """ + Creates a binary tensor that identifies if block exists in the adapter matrix + """ + block_size = matix.blocking + layout = ( + torch.ones( + (matix.size()[0] // block_size, matix.size()[1] // block_size), + dtype=torch.int32, + device=matix.device, + ) + * -1 + ) + blck = 0 + for r, c in zip(matix.row_indices, matix.column_indices): + layout[r.item(), c.item()] = blck + blck += 1 + return layout.contiguous() diff --git a/mttl/models/modifiers/sparse_utils/utils.py b/mttl/models/modifiers/sparsity/sparse_utils/utils.py similarity index 77% rename from mttl/models/modifiers/sparse_utils/utils.py rename to mttl/models/modifiers/sparsity/sparse_utils/utils.py index d8762bb7d..31c168818 100644 --- a/mttl/models/modifiers/sparse_utils/utils.py +++ b/mttl/models/modifiers/sparsity/sparse_utils/utils.py @@ -655,3 +655,168 @@ def backward(ctx, output_grad): return tuple(grads) else: return (grads[0], None) + tuple(grads[2:]) + + +import torch + + +def padded_gather(x, indices, E, block_size=16): + """ + Permute tokens to group them by expert. 
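+    For example, with bs=2, seq_len=3, top_k=2 and E=4 experts, the 12 (token, expert)
+    assignments are sorted by expert id and each expert's group is padded up to a
+    multiple of block_size by duplicating its last token; padded positions are marked
+    False in the returned padding_mask.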
+ Ensures that the number of tokens per expert is divisible by block_size by adding padding. + Returns additional data for padded_scatter. + """ + batch_size, seq_len, d_model = x.size() + top_k = indices.size(-1) + + # Step 1: Flatten x and indices + x_flat = x.view(-1, d_model) # [batch_size * seq_len, d_model] + indices_flat = indices.view(-1) # [batch_size * seq_len * top_k] + + # Step 2: Expand x to match indices + x_flat_expanded = x_flat.unsqueeze(1).expand( + -1, top_k, -1 + ) # [batch_size * seq_len, top_k, d_model] + x_expert = x_flat_expanded.reshape( + -1, d_model + ) # [batch_size * seq_len * top_k, d_model] + + # Step 3: Sort indices and x_expert to group tokens by expert + indices_expert, sort_order = indices_flat.sort() + x_expert_sorted = x_expert[sort_order] + + # Step 4: Compute number of tokens per expert + num_tokens_per_expert = torch.bincount(indices_expert, minlength=E) # [E] + + # Step 5: Compute padded number of tokens per expert + padded_num_tokens_per_expert = ( + (num_tokens_per_expert + block_size - 1) // block_size + ) * block_size # [E] + max_tokens_per_expert = padded_num_tokens_per_expert.max().item() + + # Step 6: Compute positions within each expert + def compute_positions_in_group(indices_expert): + unique_indices, counts = indices_expert.unique_consecutive(return_counts=True) + positions_in_expert = torch.cat( + [torch.arange(count, device=indices_expert.device) for count in counts] + ) + return positions_in_expert + + positions_in_expert = compute_positions_in_group(indices_expert) + + # Step 7: Pad the tokens per expert to make counts divisible by block_size + # For each expert, determine padding needed + padding_needed = padded_num_tokens_per_expert - num_tokens_per_expert # [E] + + indices_expert_padded = [] + positions_in_expert_padded = [] + x_expert_padded = [] + padding_mask = [] + + current_idx = 0 + for e in range(E): + count = num_tokens_per_expert[e].item() + padded_count = padded_num_tokens_per_expert[e].item() + padding = padding_needed[e].item() + + # Get the indices and positions for the current expert + indices_e = indices_expert[current_idx : current_idx + count] + positions_e = positions_in_expert[current_idx : current_idx + count] + x_expert_e = x_expert_sorted[current_idx : current_idx + count] + + # Append original tokens + indices_expert_padded.append(indices_e) + positions_in_expert_padded.append(positions_e) + x_expert_padded.append(x_expert_e) + padding_mask.append( + torch.ones(count, dtype=torch.bool, device=indices_expert.device) + ) + + # If padding is needed, duplicate the last token 'padding' times + if padding > 0: + indices_e_pad = indices_e.new_full((padding,), fill_value=e) + positions_e_pad = positions_e.new_tensor(range(count, padded_count)) + # For x_expert, duplicate the last token + x_expert_e_pad = x_expert_e[-1:].expand(padding, -1) # Duplicate last token + + indices_expert_padded.append(indices_e_pad) + positions_in_expert_padded.append(positions_e_pad) + x_expert_padded.append(x_expert_e_pad) + padding_mask.append( + torch.zeros(padding, dtype=torch.bool, device=indices_expert.device) + ) + + current_idx += count + + # Concatenate all the padded indices, positions, tokens, and mask + indices_expert_padded = torch.cat(indices_expert_padded, dim=0) + positions_in_expert_padded = torch.cat(positions_in_expert_padded, dim=0) + x_expert_padded = torch.cat(x_expert_padded, dim=0) + padding_mask = torch.cat(padding_mask, dim=0) # [total_padded_tokens] + + # Step 8: Initialize output tensor + output = x.new_zeros(E, 
max_tokens_per_expert, d_model) + + # Step 9: Assign tokens to output tensor + output[indices_expert_padded, positions_in_expert_padded] = x_expert_padded + + # Return additional information for padded_scatter + return ( + output, + num_tokens_per_expert, + sort_order, + indices_expert_padded, + positions_in_expert_padded, + padding_mask, + ) + + +def padded_scatter( + x, num_tokens_per_expert, sort_order, batch_size, seq_len, top_k, d_model +): + """ + Un-permute tokens back to their original positions. + + Args: + x (torch.Tensor): Input tensor of shape [E, max_tokens_per_expert, d_model], outputs from experts. + num_tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert, shape [E]. + sort_order (torch.Tensor): The sort order used in padded_gather. + batch_size (int): Original batch size. + seq_len (int): Original sequence length. + top_k (int): Number of experts per token. + d_model (int): Model dimension. + + Returns: + output (torch.Tensor): Un-permuted tokens, shape [batch_size, seq_len, top_k, d_model]. + """ + E, max_tokens_per_expert, _ = x.size() + device = x.device + + # Step 1: Flatten x and remove padding + x_flat = x.view(-1, d_model) # [E * max_tokens_per_expert, d_model] + + # Step 2: Build indices for valid tokens + expert_indices = torch.repeat_interleave( + torch.arange(E, device=device), num_tokens_per_expert + ) + positions_in_expert = torch.cat( + [torch.arange(n, device=device) for n in num_tokens_per_expert] + ) + valid_positions = expert_indices * max_tokens_per_expert + positions_in_expert + + # Step 3: Select valid tokens + x_valid = x_flat[valid_positions] + + # Step 4: Reconstruct x_expert_sorted + x_expert_sorted = x_valid + + # Step 5: Reconstruct x_expert using inverse of sort_order + x_expert = torch.empty( + (batch_size * seq_len * top_k, d_model), device=device, dtype=x.dtype + ) + x_expert[sort_order] = x_expert_sorted + + # Step 6: Reshape to [batch_size, seq_len, top_k, d_model] + x_unpermuted = x_expert.view(batch_size, seq_len, top_k, d_model) + + return x_unpermuted diff --git a/mttl/models/modifiers/sparsity/spb_moe/__init__.py b/mttl/models/modifiers/sparsity/spb_moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mttl/models/modifiers/sparsity/spb_moe/benchmark.py b/mttl/models/modifiers/sparsity/spb_moe/benchmark.py new file mode 100644 index 000000000..5c57ecc14 --- /dev/null +++ b/mttl/models/modifiers/sparsity/spb_moe/benchmark.py @@ -0,0 +1,200 @@ +import time +from functools import partial + +import numpy as np +import stk +import torch +from pytorch_lightning import seed_everything +from stk.matrix import Matrix + +from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops +from mttl.models.modifiers.sparsity.spb_moe import ops + + +def dumb_forward(base_act, x, expert_p, expert_idxs, adaps): + output = torch.stack( + [ + sum( + base_act[i] + + (expert_p[i, j] * torch.matmul(x[i], (adaps[expert_idxs[i, j]]))) + for j in range(expert_idxs.size(1)) + ) + for i in range(expert_idxs.size(0)) + ], + dim=0, + ) + return output + + +def benchmark_module(name, function, runs=100): + # Warm-up to ensure accurate measurement + for _ in range(10): + out = function() + + forward_time_total = 0.0 + + # Benchmark runs + for _ in range(runs): + # Forward pass timing + torch.cuda.synchronize() + start_time = time.time() + out = function() + torch.cuda.synchronize() + forward_time = time.time() - start_time + + # Accumulate times + forward_time_total += forward_time + + 
avg_forward_time = forward_time_total / runs + + # Measure memory usage + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + out = function() # Forward pass to record memory + memory_allocated = torch.cuda.max_memory_allocated() + memory_reserved = torch.cuda.max_memory_reserved() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + print( + f"Average forward time {name}: {avg_forward_time:.6f}s", + f"Memory allocated: {memory_allocated/1024**2:.2f}MB", + f"Memory reserved: {memory_reserved/1024**2:.2f}MB", + ) + + +def calculate_lora_parameters(input_dim, output_dim, rank): + return input_dim * rank + output_dim * rank + + +def find_lora_hyperpaams(d_in, d_out, tot_sparse_params): + lora_ranks = [] + lora_rank = 1 + for rank in range(1, d_in): + lora_params = calculate_lora_parameters(d_in, d_out, rank) + if lora_params <= tot_sparse_params: + lora_rank = rank + else: + break + lora_ranks.append(lora_rank) + return int(np.mean(lora_ranks)) + + +MOE_TESTCASES = { + # bs, d, h, E, k, sparsity, blocking, dtype + (1024, 2048, 8192, 20, 2, 0.995, 16, torch.float16), + (1024, 2048, 8192, 20, 2, 0.9, 128, torch.float16), + (1024, 2048, 8192, 100, 2, 0.995, 16, torch.float16), + (1024, 2048, 8192, 100, 2, 0.9, 128, torch.float16), + # (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float16), + # (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float16), + # (8, 128, 256, 10, 2, 0.8, 16, torch.float16), +} + +if __name__ == "__main__": + for bs, d, h, E, k, sparsity, blocking, dtype in MOE_TESTCASES: + print("=====================================================================") + print( + f"***** Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype} *****" + ) + + torch.manual_seed(42) + # print(f"Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype}") + logits = torch.randn(bs, E, dtype=dtype) + weights = torch.softmax(logits.float(), axis=-1).cuda().to(dtype) + X = torch.randn(bs, d, dtype=dtype, requires_grad=True).cuda() + W = torch.randn(d, h, dtype=dtype, requires_grad=True).cuda() + adaps = [ + matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) + for _ in range(E) + ] + adaps_sparse = [adap[1] for adap in adaps] + adaps_dense = [adap[0] for adap in adaps] + ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0) + row_idxs = torch.stack([adap.row_indices for adap in adaps_sparse], dim=0) + col_idxs_t = torch.stack( + [adap.column_indices_t for adap in adaps_sparse], dim=0 + ) + offsets_t = torch.stack([adap.offsets_t for adap in adaps_sparse], dim=0) + block_offsets_t = torch.stack( + [adap.block_offsets_t for adap in adaps_sparse], dim=0 + ) + + k_weights, expert_idxs = torch.topk(weights, k) + + def call_with_baseact_and_idxs_computation( + X, W, expert_idxs, function, **kwargs + ): + base_act = torch.matmul(X, W) + sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort( + expert_idxs + ) + padded_block_idxs, expert_offsets = ops.padded_block_indices( + sorted_expert_idxs, E + ) + return function( + x=X, + base_act=base_act, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + **kwargs, + ) + + # base_act = torch.matmul(X, W) + func = partial( + call_with_baseact_and_idxs_computation, + X=X, + W=W, + expert_idxs=expert_idxs, + function=ops.scattergather_adamerge, + k=k, + ada_weights=ada_data, + row_idxs=row_idxs, + 
col_idxs=col_idxs_t,
+            offsets=offsets_t,
+            block_offsets_t=block_offsets_t,
+            ada_block_size=blocking,
+            gates=k_weights,
+        )
+        benchmark_module("BS kernel not optimized", func)
+        # func_dummb = partial(dumb_forward, base_act=base_act, x=X, expert_p=k_weights, expert_idxs=expert_idxs, adaps=adaps_dense)
+        # benchmark_module("dummy forward", func_dummb)
+
+        func_opt = partial(
+            call_with_baseact_and_idxs_computation,
+            X=X,
+            W=W,
+            expert_idxs=expert_idxs,
+            function=ops.scattergather_adamerge_opt,
+            k=k,
+            ada_weights=ada_data,
+            row_idxs=row_idxs,
+            col_idxs=col_idxs_t,
+            offsets=offsets_t,
+            block_offsets_t=block_offsets_t,
+            ada_block_size=blocking,
+            gates=k_weights,
+        )
+        benchmark_module("BS kernel optimized", func_opt)
+        lora_rank = find_lora_hyperpaams(d, h, np.prod(ada_data.shape[1:]))
+
+        def lora_merge(lora_a, lora_b, x, W_base, W_merge):
+            # LoRA does not benefit from a lower top-k in this vanilla form
+            # merge into 1 lora
+            A = torch.einsum("be,edr->bdr", (W_merge, lora_a))
+            B = torch.einsum("be,erd->brd", (W_merge, lora_b))
+            # lora forward
+            partial_out = torch.einsum("bd,bdr->br", (x, A))
+            adapter_out = torch.einsum("br,brd->bd", (partial_out, B))
+            dense_out = x @ W_base
+            return adapter_out + dense_out
+
+        lora_a = torch.randn(E, d, lora_rank, dtype=dtype).cuda().contiguous()
+        lora_b = torch.randn(E, lora_rank, h, dtype=dtype).cuda().contiguous()
+        func_lora = partial(
+            lora_merge, lora_a=lora_a, lora_b=lora_b, x=X, W_base=W, W_merge=weights
+        )
+        benchmark_module("LoRA merge (our current vanilla)", func_lora)
diff --git a/mttl/models/modifiers/sparsity/spb_moe/functions.py b/mttl/models/modifiers/sparsity/spb_moe/functions.py
new file mode 100644
index 000000000..023e5abae
--- /dev/null
+++ b/mttl/models/modifiers/sparsity/spb_moe/functions.py
@@ -0,0 +1,100 @@
+from typing import Any
+
+import torch
+from stk.backend.autocast import custom_bwd, custom_fwd
+from stk.matrix import Matrix
+
+from mttl.models.modifiers.sparsity.spb_moe.triton_kernels import (
+    scatter2scatter_sparse,
+    scatter2scatter_sparse_optimized,
+)
+
+
+class ParalleLinear(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(
+        ctx,
+        x,
+        base_act,
+        k,
+        ada_weights,
+        row_idxs,
+        col_idxs,
+        offsets,
+        block_offsets_t,
+        ada_block_size,
+        sorted_expert_idxs,
+        sorted_scattered_idxs,
+        padded_block_idxs,
+        gates,
+    ):
+
+        output = scatter2scatter_sparse(
+            X=x,
+            base_act=base_act,
+            ada_weights=ada_weights,
+            row_idxs=row_idxs,
+            col_idxs_t=col_idxs,
+            offsets_t=offsets,
+            block_offsets_t=block_offsets_t,
+            ada_block=ada_block_size,
+            sorted_expert_idxs=sorted_expert_idxs,
+            sorted_scattered_idxs=sorted_scattered_idxs,
+            padded_block_idxs=padded_block_idxs,
+            k=k,
+            gates=gates,
+        )
+        output = output.view(gates.size(0), gates.size(1), output.size(-1)).sum(
+            1
+        )  # this can be moved into the kernel?
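+        # The kernel returns one row per (token, expert) assignment; the view + sum
+        # over dim 1 above collapses the top-k expert outputs for each token.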
+ return output + + +parallel_linear = ParalleLinear.apply + + +class ParalleLinear_optim(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx, + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates=None, + ): + + output = scatter2scatter_sparse_optimized( + X=x, + base_act=base_act, + ada_weights=ada_weights, + row_idxs=row_idxs, + col_idxs_t=col_idxs, + offsets_t=offsets, + block_offsets_t=block_offsets_t, + ada_block=ada_block_size, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + padded_block_idxs=padded_block_idxs, + k=k, + gates=gates, + ) + output = output.view(gates.size(0), gates.size(1), output.size(-1)).sum( + 1 + ) # this can be moved into kernel? + return output + + +parallel_linear_optimized = ParalleLinear_optim.apply diff --git a/mttl/models/modifiers/sparsity/spb_moe/ops.py b/mttl/models/modifiers/sparsity/spb_moe/ops.py new file mode 100644 index 000000000..de400acba --- /dev/null +++ b/mttl/models/modifiers/sparsity/spb_moe/ops.py @@ -0,0 +1,176 @@ +import torch +from stk.matrix import Matrix + +from mttl.models.modifiers.sparsity.spb_moe import functions + + +def sdd_adamerge(a, b, out_topo: Matrix, out_adaps: Matrix, layout): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(out_topo, Matrix) + assert out_topo.is_contiguous() + assert isinstance(out_adaps, Matrix) + assert out_adaps.data.is_contiguous() + assert isinstance(layout, torch.Tensor) + assert layout.is_contiguous() + # essentially merged the adapters into a single Matrix() + assert ( + out_adaps.shape[1] == out_topo.shape[1] + ), "This performs sparse SDD of a and b, the output topo should have the same number of columns as the out_adaps" + assert ( + out_adaps.shape[1] % b.size(1) == 0 + ), "The number of columns in out_adaps should be a multiple of the number of columns in b" + + out = functions.sdd_spsmerge( + a, + b, + out_topo.size(), + out_topo.data, + out_topo.row_indices, + out_topo.column_indices, + out_topo.column_indices_t, + out_topo.block_offsets_t, + out_adaps.data, + layout, + ) + return Matrix( + out_topo.size(), + out, + out_topo.row_indices, + out_topo.column_indices, + out_topo.offsets, + out_topo.column_indices_t, + out_topo.offsets_t, + out_topo.block_offsets_t, + ) + + +def scattergather_adamerge( + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates, +): + + out = functions.parallel_linear( + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates, + ) + return out + + +def scattergather_adamerge_opt( + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates, +): + + out = functions.parallel_linear_optimized( + x, + base_act, + k, + ada_weights, + row_idxs, + col_idxs, + offsets, + block_offsets_t, + ada_block_size, + sorted_expert_idxs, + sorted_scattered_idxs, + padded_block_idxs, + gates, + ) + return out + + +BLOCK_M = 128 # expert token capacity + + +@torch.jit.script +def flatten_and_sort(expert_idxs: torch.Tensor): + """ + Flattens a tensor of expert indices and 
sorts the flattened tensor. + + Args: + expert_idxs (torch.Tensor): A tensor containing expert indices. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - sorted_expert_idxs: The flattened and sorted expert indices. + - sorted_scattered_idxs: The indices that would sort the flattened tensor. + """ + flattened_expert_idxs = expert_idxs.flatten() + sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs) + return sorted_expert_idxs, sorted_scattered_idxs + + +@torch.jit.script +def padded_block_indices( + sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int = BLOCK_M +): + """ + Compute padded block indices for sorted experts. + + This function calculates the indices of padded blocks for a given set of sorted expert indices. + It ensures that the blocks are padded to a specified block size. + + Args: + sorted_experts_idxs (torch.Tensor): A tensor containing the sorted indices of experts. + k (int): The number of unique experts. + N_BLOCK_SIZE (int, optional): The size of each block. Defaults to BLOCK_M. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - expanded_block_idxs (torch.Tensor): The indices of the expanded blocks. + - expert_boundaries_end (torch.Tensor): The end boundaries of the experts. + """ + expert_counts = torch.bincount(sorted_experts_idxs, minlength=k) + padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1 + padded_expert_block_end = padded_block_counts.cumsum(-1) + expert_boundaries_end = expert_counts.cumsum(-1) + expert_boundaries_start = expert_boundaries_end - expert_counts + padded_expert_block_start = padded_expert_block_end - padded_block_counts + block_idxs = torch.arange( + padded_expert_block_end[-1], + dtype=sorted_experts_idxs.dtype, + device=sorted_experts_idxs.device, + ) + block_mask = (block_idxs[:, None] < padded_expert_block_start) | ( + block_idxs[:, None] >= padded_expert_block_end + ) + expanded_block_idxs = ( + N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) + + expert_boundaries_start + ) + expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1) + return expanded_block_idxs, expert_boundaries_end diff --git a/mttl/models/modifiers/sparsity/spb_moe/triton_kernels.py b/mttl/models/modifiers/sparsity/spb_moe/triton_kernels.py new file mode 100644 index 000000000..d757f00a5 --- /dev/null +++ b/mttl/models/modifiers/sparsity/spb_moe/triton_kernels.py @@ -0,0 +1,702 @@ +import torch +import triton +import triton.language as tl +from torch.nn import functional as F + +BLOCK_M = 128 + + +def _scatter2scatter_configs(): + return [ + triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=1, num_warps=1), + ] + + +@triton.autotune( + configs=_scatter2scatter_configs(), + key=["M", "N", "K"], +) +@triton.heuristics( + { + "NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0, + "NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0, + "ADA_BLCKS_PER_TILE_K": lambda args: args["BLOCK_K"] // args["ADA_BLOCK"], + "ADA_BLCKS_PER_TILE_N": lambda args: args["BLOCK_N"] // args["ADA_BLOCK"], + } +) +@triton.jit +def _scatter2scatter( + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_wk, + stride_wn, + adaW, # n_exp x ada_block x ada_block + ada_layout, + stride_layout_e, + stride_layout_m, + stride_layout_n, + Y_ptr, + stride_ym, + stride_yn, + sorted_scattered_idxs, + sorted_expert_idxs, + padded_block_idxs, + FAN_OUT, + M: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_M: tl.constexpr, + 
BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + x_grouped: tl.constexpr, + y_grouped: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + ADA_BLOCK: tl.constexpr, + ADA_BLCKS_PER_TILE_K: tl.constexpr, # how many ada blocks in one tile in K direction + ADA_BLCKS_PER_TILE_N: tl.constexpr, # how many ada blocks in one tile in N direction +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv( + N, BLOCK_N + ) # is 2? numbe of blocks per expert's output dimension + M_block_id = ( + pid // N_BLOCK_COUNT + ) # which expert are we in? (actually block, since there might be multiple blocks per expert) + N_block_id = pid % N_BLOCK_COUNT # which block in the out. dim are we in? + # Determine the block indices along the M and N dimensions for this program. + + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load( + padded_block_idxs + M_block_id + ) # Load the index of the starting token for this block + # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M) + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) # max tokens + E_idxs = tl.load( + sorted_expert_idxs + M_block, mask=M_block < (FAN_OUT * M), other=E + ) # expert_idxs_ptr is sorted by expert! so this loads expert indices of tokens + E_idx = tl.min(E_idxs) + E_mask = E_idxs == E_idx + M_idx = tl.load(sorted_scattered_idxs + M_block, mask=E_mask, other=0) + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = M_idx // FAN_OUT + + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + + K_block = tl.arange(0, BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + + L_BLOCK_K = tl.arange(0, ADA_BLCKS_PER_TILE_K) + L_BLOCK_N = tl.arange(0, ADA_BLCKS_PER_TILE_N) + additive_idx_blocks = (tl.arange(0, ADA_BLOCK))[:, None] * ADA_BLOCK + ( + tl.arange(0, ADA_BLOCK) + )[None, :] + L_blck_ptrs = ( + ada_layout + + L_BLOCK_K[:, None] * stride_layout_m + + L_BLOCK_N[None, :] * stride_layout_n + + N_block_id * ADA_BLCKS_PER_TILE_N + + E_idx * stride_layout_e + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(K, BLOCK_K) + for K_block_id in range(0, iters): + if NO_K_MASK: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) + if NO_N_MASK or K_block_id < (iters - 1): + w = tl.load(W_blk_ptrs) + else: + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + + layout_tile = tl.load(L_blck_ptrs) # 2 x 8 + # BETTER TO RESAHPE MEMORY ADDRESSES, NOT THE LOADED DATA? 
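+        # layout_tile holds, for every ADA_BLOCK x ADA_BLOCK sub-block of this tile,
+        # the index of the corresponding adapter block in adaW (-1 if the block is
+        # absent); the masked load below materializes the adapter tile with zeros for
+        # absent blocks before it is added to the dense weight tile w.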
+ mask = layout_tile >= 0 + base_addresses = adaW + (layout_tile * (ADA_BLOCK * ADA_BLOCK)) + full_addresses = ( + base_addresses[:, None, :, None] + additive_idx_blocks[None, :, None, :] + ) + full_addresses = full_addresses.reshape( + ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK + ) + mask = mask[:, None, :, None] * ( + tl.zeros((1, ADA_BLOCK, 1, ADA_BLOCK), dtype=ACC_TYPE) + 1.0 + ) + mask = ( + mask.reshape( + ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK + ) + > 0.0 + ) + + adaW_tile = tl.load( + full_addresses, + mask=mask, + other=0.0, + ) + w = ( + w + adaW_tile + ) # .reshape(ADA_BLCKS_PER_TILE_K * ADA_BLOCK, ADA_BLCKS_PER_TILE_N * ADA_BLOCK) + L_blck_ptrs += ADA_BLCKS_PER_TILE_K * stride_layout_m + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + acc += tl.dot(x, w, out_dtype=ACC_TYPE) + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + +def scatter2scatter( + X, + W, + ada_weights, + ada_block, + ada_layout, + sorted_expert_idxs, + sorted_scattered_idxs, + k, + padded_block_idxs, + x_grouped=False, + y_grouped=False, + out=None, +): + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * k + # Pre-kernel setup + x_dim = X.size(-1) + y_dim = W.size(-1) + L_scattered = sorted_expert_idxs.size(0) + if out is None: + O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) + # O = torch.empty_like(ada_weights, device=X.device, dtype=ada_weights.dtype) + else: + assert out.size(0) == L_scattered and out.size(1) == y_dim + O = out + + def grid(META): + grid_num = ( + padded_block_idxs.size(0) * triton.cdiv(META["N"], META["BLOCK_N"]), + ) + return grid_num + + assert _scatter2scatter_configs()[0].kwargs["BLOCK_N"] % ada_block == 0 + assert _scatter2scatter_configs()[0].kwargs["BLOCK_K"] % ada_block == 0 + assert (ada_layout.size(1) * ada_block) % W.size(-1) == 0 + + M, K = X.size() + N = y_dim + E = (ada_layout.size(1) * ada_block) // W.size(-1) + ada_layout_stride_e = N // ada_block + # sorted_expert_idxs = sorted_expert_idxs.to(torch.int32) + # sorted_scattered_idxs = sorted_scattered_idxs.to(torch.int32) + # padded_block_idxs = padded_block_idxs.to(torch.int32) + + # with torch.cuda.device(X.device): + _scatter2scatter[grid]( + X, + X.stride(0), + X.stride(1), + W, + W.stride(0), + W.stride(1), + ada_weights, # n_exp x ada_block x ada_block + ada_layout, + ada_layout_stride_e, + ada_layout.stride(0), + ada_layout.stride(1), + O, + O.stride(0), + O.stride(1), + sorted_scattered_idxs, + sorted_expert_idxs, + padded_block_idxs, + FAN_OUT=k, + M=M, + K=K, + N=N, + E=E, + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + allow_tf32=True, + x_grouped=x_grouped, + y_grouped=y_grouped, + ADA_BLOCK=ada_block, + ) + return O + + +def _scatter2scatter_sp_configs(): + return [ + # triton.Config({"BLOCK_K": 128}, num_stages=4, num_warps=4), + ] + + +@triton.autotune( + configs=_scatter2scatter_sp_configs(), + key=["M", "N"], +) +@triton.jit +def _scatter2scatter_sp( + X_ptr, + stride_xm, + stride_xk, + gates, + adaW, # n_exp x ada_block x ada_block + adaW_stride_e, + adaW_stride_m, + adaW_stride_n, + base_acts, + column_indices_t, # gives the row index column by column (row indexes sorted by column starting witht he first one etc.) + column_indices_t_offset, + offsets_t, # offsets for columns: i.e. 
the diff between two consecutive gives the number of blocks per column. It indexes column_indices_t, block_offsets_t + offsets_t_offset, + block_offsets_t, # indices of blocks sorted by column + block_offsets_t_offset, + Y_ptr, + stride_ym, + stride_yn, + # OW, + sorted_scattered_idxs, + sorted_expert_idxs, + padded_block_idxs, + FAN_OUT, + M, + N, + E, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_M: tl.constexpr, + ACC_TYPE: tl.constexpr, +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv( + N, BLOCK_N + ) # is 2? numbe of blocks per expert's output dimension + M_block_id = ( + pid // N_BLOCK_COUNT + ) # which expert are we in? (actually block, since there might be multiple blocks per expert) + N_block_id = pid % N_BLOCK_COUNT # which block in the out. dim are we in? + # Determine the block indices along the M and N dimensions for this program. + + M_range = tl.arange(0, BLOCK_M) + block_start_idx = tl.load( + padded_block_idxs + M_block_id + ) # Load the index of the starting token for this block + # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M) + M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M) # max tokens + E_idxs = tl.load( + sorted_expert_idxs + M_block, mask=M_block < (FAN_OUT * M), other=E + ) # expert_idxs_ptr is sorted by expert! so this loads expert indices of tokens + E_idx = tl.min(E_idxs) + E_mask = E_idxs == E_idx + M_idx = tl.load(sorted_scattered_idxs + M_block, mask=E_mask, other=0) + M_in_idx = M_idx // FAN_OUT + M_out_idx = M_idx + + K_block = tl.arange(0, BLOCK_K) + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + start_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id) + end_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id + 1) + num_blocks_column = end_inx - start_inx + iters = num_blocks_column # tl.cdiv(num_blocks_column, tl.cdiv(BLOCK_K, ADA_BLOCK)) # n_blocks_column + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + gate = tl.load(gates + M_idx, mask=E_mask) + + if iters > 0: + # pointers to dense matrix + X_blk_ptr = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) # ...16 + rbk = tl.arange(0, BLOCK_K) # ... 
16 + W_blk_ptr = ( + adaW + + (rbk[:, None] * adaW_stride_m) + + (rn[None, :] * adaW_stride_n) + + (E_idx * adaW_stride_e) + ) + BLOCK_SIZE = BLOCK_K * BLOCK_N + ak_block_incr = stride_xk * BLOCK_K + + # OW_block_ptr = OW + (rbk[:, None] * adaW_stride_m) + (rn[None, :] * adaW_stride_n) + (E_idx * adaW_stride_e) + + for K_block_id in range(0, iters): + X = ( + X_blk_ptr + + tl.load( + column_indices_t + + (E_idx * column_indices_t_offset) + + start_inx + + K_block_id + ) + * ak_block_incr + ) + + W = ( + W_blk_ptr + + tl.load( + block_offsets_t + + (E_idx * block_offsets_t_offset) + + start_inx + + K_block_id + ) + * BLOCK_SIZE + ) + # OWW = OW_block_ptr + tl.load(block_offsets_t + (E_idx * block_offsets_t_offset) + start_inx + K_block_id) * BLOCK_SIZE + + x = tl.load(X, mask=E_mask[:, None]) + w = tl.load(W, mask=N_mask[None, :]) + acc += tl.dot(x, w, out_dtype=ACC_TYPE) + + # tl.store(OWW, w) + + base_act_ptr = ( + base_acts + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn + ) + base_act = tl.load(base_act_ptr, mask=E_mask[:, None] & N_mask[None, :]) + acc *= gate[:, None] + acc += base_act + + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :]) + + +def scatter2scatter_sparse( + X, + base_act, + ada_weights, + row_idxs, + col_idxs_t, + ada_block, + offsets_t, + block_offsets_t, + sorted_expert_idxs, + sorted_scattered_idxs, + k, + padded_block_idxs, + gates, + out=None, +): + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * k + assert X.is_contiguous() + assert base_act.is_contiguous() + assert ada_weights.is_contiguous() + assert row_idxs.is_contiguous() + assert col_idxs_t.is_contiguous() + assert offsets_t.is_contiguous() + assert block_offsets_t.is_contiguous() + assert sorted_expert_idxs.is_contiguous() + assert sorted_scattered_idxs.is_contiguous() + assert padded_block_idxs.is_contiguous() + assert gates.is_contiguous() + + # Pre-kernel setup + x_dim = X.size(-1) + y_dim = base_act.size(-1) + L_scattered = sorted_expert_idxs.size(0) + if out is None: + O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) + # O = torch.empty_like(ada_weights, device=X.device, dtype=ada_weights.dtype) + else: + assert out.size(0) == L_scattered and out.size(1) == y_dim + O = out + + # OW = torch.empty_like(ada_weights, device=X.device, dtype=ada_weights.dtype) + def grid(META): + grid_num = ( + padded_block_idxs.size(0) * triton.cdiv(META["N"], META["BLOCK_N"]), + ) + return grid_num + + M, K = X.size() + N = y_dim + E = ada_weights.size(0) + with torch.cuda.device(X.device): + _scatter2scatter_sp[grid]( + X, + X.stride(0), + X.stride(1), + gates, + ada_weights, # n_exp x ada_block x ada_block + ada_weights.stride(0), + ada_weights.stride(2), + ada_weights.stride(3), + base_act, + col_idxs_t, + col_idxs_t.stride(0), + offsets_t, # column offsets shapre is (E, N//ada_block + 1) + offsets_t.stride(0), + block_offsets_t, + block_offsets_t.stride(0), + O, + O.stride(0), + O.stride(1), + # OW, + sorted_scattered_idxs, + sorted_expert_idxs, + padded_block_idxs, + FAN_OUT=k, + M=M, + N=N, + E=E, + BLOCK_M=BLOCK_M, + BLOCK_K=ada_block, + BLOCK_N=ada_block, + ACC_TYPE=tl.float32, + ) + return O + + +@triton.autotune( + configs=[ + triton.Config({"GROUP_M": 1, "BLOCK_M": 128}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 4, "BLOCK_M": 128}, num_stages=4, num_warps=4), + triton.Config({"GROUP_M": 
+        triton.Config({"GROUP_M": 128, "BLOCK_M": 128}, num_stages=4, num_warps=4),
+        triton.Config({"GROUP_M": 1, "BLOCK_M": 64}, num_stages=4, num_warps=4),
+        triton.Config({"GROUP_M": 4, "BLOCK_M": 64}, num_stages=4, num_warps=4),
+        triton.Config({"GROUP_M": 32, "BLOCK_M": 64}, num_stages=4, num_warps=4),
+        triton.Config({"GROUP_M": 128, "BLOCK_M": 64}, num_stages=4, num_warps=4),
+        triton.Config({"GROUP_M": 1, "BLOCK_M": 256}, num_stages=4, num_warps=4),
+        triton.Config({"GROUP_M": 4, "BLOCK_M": 256}, num_stages=4, num_warps=4),
+        triton.Config({"GROUP_M": 32, "BLOCK_M": 256}, num_stages=4, num_warps=4),
+        triton.Config({"GROUP_M": 128, "BLOCK_M": 256}, num_stages=4, num_warps=4),
+    ],
+    key=["M", "N", "E"],
+)
+@triton.jit
+def _scatter2scatter_sp_optimized(
+    X_ptr,
+    stride_xm,
+    stride_xk,
+    gates,
+    adaW,  # (n_experts, n_blocks, ada_block, ada_block)
+    adaW_stride_e,
+    adaW_stride_m,
+    adaW_stride_n,
+    base_acts,
+    column_indices_t,  # row index of each block, column by column (row indices sorted by column, starting with the first column)
+    column_indices_t_offset,
+    offsets_t,  # offsets for columns: the difference between two consecutive entries gives the number of blocks per column; it indexes column_indices_t and block_offsets_t
+    offsets_t_offset,
+    block_offsets_t,  # indices of blocks sorted by column
+    block_offsets_t_offset,
+    Y_ptr,
+    stride_ym,
+    stride_yn,
+    sorted_scattered_idxs,
+    sorted_expert_idxs,
+    padded_block_idxs,
+    FAN_OUT: tl.constexpr,
+    M: tl.constexpr,
+    N: tl.constexpr,
+    E: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    ACC_TYPE: tl.constexpr,
+    GROUP_M: tl.constexpr,
+    MAX_K_ITERS: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+
+    num_pid_m = tl.num_programs(0)
+    num_pid_n = tl.num_programs(1)
+    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
+
+    M_block_id = pid_m  # which token block we are in (several blocks may map to the same expert)
+    N_block_id = pid_n  # which block along the output dimension
+    M_range = tl.arange(0, BLOCK_M)
+    block_start_idx = tl.load(
+        padded_block_idxs + M_block_id
+    )  # load the index of the starting token for this block
+    # M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M)
+    M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)  # token positions covered by this block
+    E_idxs = tl.load(
+        sorted_expert_idxs + M_block, mask=M_block < (FAN_OUT * M), other=E
+    )  # sorted_expert_idxs is sorted by expert, so this loads the expert index of each token
+    E_idx = tl.min(E_idxs)
+    E_mask = E_idxs == E_idx
+    M_idx = tl.load(sorted_scattered_idxs + M_block, mask=E_mask, other=0)
+    M_in_idx = M_idx // FAN_OUT
+    M_out_idx = M_idx
+
+    K_block = tl.arange(0, BLOCK_K)
+    N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+    N_mask = N_block < N
+    start_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id)
+    end_inx = tl.load(offsets_t + (E_idx * offsets_t_offset) + N_block_id + 1)
+    num_blocks_column = end_inx - start_inx
+    iters = num_blocks_column  # tl.cdiv(num_blocks_column, tl.cdiv(BLOCK_K, ADA_BLOCK))  # n_blocks_column
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+
+    gate = tl.load(gates + M_idx, mask=E_mask)
+
+    # pointers to the dense input matrix
+    X_blk_ptr = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
+    # pointers to the block-sparse adapter matrix
+    rn = tl.arange(0, BLOCK_N)
+    rbk = tl.arange(0, BLOCK_K)
+    W_blk_ptr = (
+        adaW
+        + (rbk[:, None] * adaW_stride_m)
+        + (rn[None, :] * adaW_stride_n)
+        + (E_idx * adaW_stride_e)
+    )
+    BLOCK_SIZE = BLOCK_K * BLOCK_N
+    ak_block_incr = stride_xk * BLOCK_K
+
+    for K_block_id in tl.range(0, MAX_K_ITERS):
+        valid = K_block_id < iters
+        X = (
+            X_blk_ptr
+            + tl.load(
+                column_indices_t
+                + (E_idx * column_indices_t_offset)
+                + start_inx
+                + K_block_id,
+                mask=valid,
+                other=0,
+            )
+            * ak_block_incr
+        )
+
+        W = (
+            W_blk_ptr
+            + tl.load(
+                block_offsets_t
+                + (E_idx * block_offsets_t_offset)
+                + start_inx
+                + K_block_id,
+                mask=valid,
+                other=0,
+            )
+            * BLOCK_SIZE
+        )
+
+        x = tl.load(X, mask=valid & E_mask[:, None], other=0.0)
+        w = tl.load(W, mask=valid & N_mask[None, :], other=0.0)
+        acc += tl.dot(x, w, out_dtype=ACC_TYPE)
+
+    base_act_ptr = (
+        base_acts + M_in_idx[:, None] * stride_ym + N_block[None, :] * stride_yn
+    )
+    base_act = tl.load(base_act_ptr, mask=E_mask[:, None] & N_mask[None, :])
+    acc *= gate[:, None]
+    acc += base_act
+
+    Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
+    tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])
+    # tl.atomic_add(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :], scope="cta")  # <- could be used to fuse the merging op into the kernel, but it's not working for some reason
+
+
+def scatter2scatter_sparse_optimized(
+    X,
+    base_act,
+    ada_weights,
+    row_idxs,
+    col_idxs_t,
+    ada_block,
+    offsets_t,
+    block_offsets_t,
+    sorted_expert_idxs,
+    sorted_scattered_idxs,
+    k,
+    padded_block_idxs,
+    gates,
+    out=None,
+):
+    assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
+    assert sorted_scattered_idxs.size(0) == X.size(0) * k
+    assert X.is_contiguous()
+    assert base_act.is_contiguous()
+    assert ada_weights.is_contiguous()
+    assert row_idxs.is_contiguous()
+    assert col_idxs_t.is_contiguous()
+    assert offsets_t.is_contiguous()
+    assert block_offsets_t.is_contiguous()
+    assert sorted_expert_idxs.is_contiguous()
+    assert sorted_scattered_idxs.is_contiguous()
+    assert padded_block_idxs.is_contiguous()
+    assert gates.is_contiguous()
+
+    # Pre-kernel setup
+    x_dim = X.size(-1)
+    y_dim = base_act.size(-1)
+    L_scattered = sorted_expert_idxs.size(0)
+    if out is None:
+        O = torch.zeros((L_scattered, y_dim), device=X.device, dtype=X.dtype)
+    else:
+        assert out.size(0) == L_scattered and out.size(1) == y_dim
+        O = out
+
+    def grid(META):
+        grid_num = (
+            padded_block_idxs.size(0),
+            triton.cdiv(META["N"], META["BLOCK_N"]),
+        )
+        return grid_num
+
+    M, K = X.size()
+    N = y_dim
+    E = ada_weights.size(0)
+    MAX_ITERS = (K + ada_block - 1) // ada_block
+    with torch.cuda.device(X.device):
+        _scatter2scatter_sp_optimized[grid](
+            X,
+            X.stride(0),
+            X.stride(1),
+            gates,
+            ada_weights,  # (n_experts, n_blocks, ada_block, ada_block)
+            ada_weights.stride(0),
+            ada_weights.stride(2),
+            ada_weights.stride(3),
+            base_act,
+            col_idxs_t,
+            col_idxs_t.stride(0),
+            offsets_t,  # column offsets, shape (E, N // ada_block + 1)
+            offsets_t.stride(0),
+            block_offsets_t,
+            block_offsets_t.stride(0),
+            O,
+            O.stride(0),
+            O.stride(1),
+            # OW,
+            sorted_scattered_idxs,
+            sorted_expert_idxs,
+            padded_block_idxs,
+            FAN_OUT=k,
+            M=M,
+            N=N,
+            E=E,
+            BLOCK_K=ada_block,
+            BLOCK_N=ada_block,
+            ACC_TYPE=tl.float32,
+            MAX_K_ITERS=MAX_ITERS,
+        )
+    return O
diff --git a/requirements.txt b/requirements.txt
index c9b4fa7fc..c33ad6f5d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,4 +25,5 @@ azure-storage-blob
 azure-identity
 einops
 nltk
+stanford-stk
 # spops @ git+https://github.com/IST-DASLab/spops.git@main
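For reference, the fused computation performed by `_scatter2scatter_sp` and `_scatter2scatter_sp_optimized` can be stated densely: each (token, expert) pair selected by the top-k router yields one output row equal to base_act[token] + gate * (X[token] @ W_expert). A minimal dense sketch of that semantics follows; `dense_adapters` is a hypothetical list of densified per-expert adapter matrices standing in for the stacked stk block-sparse tensors the kernels actually read, so this is a readability aid under that assumption, not the kernel API.

```python
import torch


def scatter2scatter_dense_reference(X, base_act, dense_adapters, gates, expert_idxs, k):
    """Dense sketch of the kernel semantics: one output row per (token, expert) pair.

    dense_adapters: hypothetical list of E dense (d, h) adapter matrices; the Triton
    kernels read the same values from the stacked block-sparse representation.
    """
    out = torch.empty(X.size(0) * k, base_act.size(-1), device=X.device, dtype=X.dtype)
    for t in range(X.size(0)):  # token index (M_in_idx in the kernels)
        for j in range(k):  # slot among the token's top-k experts
            e = int(expert_idxs[t, j])
            scattered = t * k + j  # row in the scattered output (M_out_idx)
            out[scattered] = base_act[t] + gates[t, j] * (X[t] @ dense_adapters[e])
    return out
```

This is presumably the same role `dumb_forward` plays in the new test below, which compares the Triton output against a dense per-expert loop.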
diff --git a/tests/test_bsr_moe.py b/tests/test_bsr_moe.py
new file mode 100644
index 000000000..bf0d46997
--- /dev/null
+++ b/tests/test_bsr_moe.py
@@ -0,0 +1,74 @@
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+from stk.matrix import Matrix
+
+from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops
+from mttl.models.modifiers.sparsity.spb_moe import ops
+from mttl.models.modifiers.sparsity.spb_moe.benchmark import dumb_forward
+
+blocksize = 16
+
+SC_MOE_TEST = [
+    (4, 32, 64, 10, 2, 0.8, 16, torch.float32),
+    (1024, 1024, 8192, 20, 2, 0.8, 16, torch.float32),
+    (1024, 1024, 2048, 20, 2, 0.8, 16, torch.float32),
+    (8, 128, 256, 10, 2, 0.8, 16, torch.float32),
+]
+
+
+@pytest.mark.skipif(
+    torch.cuda.is_available() is False, reason="CUDA must be available for this test."
+)
+@pytest.mark.parametrize("bs, d, h, E, k, sparsity, blocking, dtype", SC_MOE_TEST)
+def testScatteredMoE(bs, d, h, E, k, sparsity, blocking, dtype):
+    seed_everything(42)
+    device = "cuda"
+    # print(f"Running test with bs={bs}, d={d}, h={h}, E={E}, k={k}, sparsity={sparsity}, blocking={blocking}, dtype={dtype}")
+    logits = torch.randn(bs, E, dtype=dtype)
+    weights = torch.softmax(logits.float(), axis=-1).to(dtype).to(device)
+    X = torch.randn(bs, d, dtype=dtype, requires_grad=True).to(device)
+    W = torch.randn(d, h, dtype=dtype, requires_grad=True).to(device)
+    adaps = [
+        matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype) for _ in range(E)
+    ]
+    adaps_sparse = [adap[1] for adap in adaps]
+    adaps_dense = [adap[0] for adap in adaps]
+    ada_data = torch.stack([adap.data for adap in adaps_sparse], dim=0)
+    row_idxs = torch.stack([adap.row_indices for adap in adaps_sparse], dim=0)
+    col_idxs_t = torch.stack([adap.column_indices_t for adap in adaps_sparse], dim=0)
+    offsets_t = torch.stack([adap.offsets_t for adap in adaps_sparse], dim=0)
+    block_offsets_t = torch.stack(
+        [adap.block_offsets_t for adap in adaps_sparse], dim=0
+    )
+
+    k_weights, expert_idxs = torch.topk(weights, k)
+    sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(expert_idxs)
+    padded_block_idxs, expert_offsets = ops.padded_block_indices(sorted_expert_idxs, E)
+
+    base_act = torch.matmul(X, W)
+    out2 = ops.scattergather_adamerge_opt(
+        x=X,
+        base_act=base_act,
+        k=k,
+        ada_weights=ada_data,
+        row_idxs=row_idxs,
+        col_idxs=col_idxs_t,
+        offsets=offsets_t,
+        block_offsets_t=block_offsets_t,
+        ada_block_size=blocking,
+        sorted_expert_idxs=sorted_expert_idxs,
+        sorted_scattered_idxs=sorted_scattered_idxs,
+        padded_block_idxs=padded_block_idxs,
+        gates=k_weights,
+    )
+
+    out_dumb = dumb_forward(base_act, X, k_weights, expert_idxs, adaps_dense)
+    err_Y = torch.abs(out2 - out_dumb)
+    tolerance = 1e-2
+    print(err_Y.max())
+    assert err_Y.max() < tolerance, "Y error too large: max %0.05f" % err_Y.max()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/test_sparse_masks.py b/tests/test_sparse_masks.py
index ab63939ea..75fc0c866 100644
--- a/tests/test_sparse_masks.py
+++ b/tests/test_sparse_masks.py
@@ -14,7 +14,7 @@
     ScatteredSparseAdapter,
     ScatteredSparseLinearModule,
 )
-from mttl.models.modifiers.sparse_utils.sparse_linear import ScatteredSparseLinearModule
+from mttl.models.modifiers.sparsity.sparse_linear import ScatteredSparseLinearModule
 
 
 def test_sm_adapter():
diff --git a/tests/test_stk_matrix.py b/tests/test_stk_matrix.py
new file mode 100644
index 000000000..792b071fe
--- /dev/null
+++ b/tests/test_stk_matrix.py
@@ -0,0 +1,34 @@
+import pytest
+import stk
+import stk.ops
+import torch
+from stk.matrix import Matrix
+
+from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops
+
+
+@pytest.mark.skipif(
+    torch.cuda.is_available() is False, reason="CUDA must be available for this test."
+)
+@pytest.mark.parametrize(
+    "K, rows, cols, sparsity, blocking",
+    [
+        (2, 8, 16, 0.5, 1),
+        (2, 8, 16, 0.5, 4),
+    ],
+)
+def test_layout_creation(K, rows, cols, sparsity, blocking):
+    adaps = [
+        matrix_ops._dense_and_sparse(rows, cols, sparsity, blocking, torch.float16)
+        for _ in range(K)
+    ]
+    adaps_sparse = [adap[1] for adap in adaps]
+    # adaps_dense = [adap[0] for adap in adaps]
+
+    merged_adaps_matrix: Matrix = matrix_ops.merge_adapters(adaps_sparse)
+    layout = matrix_ops.create_ada_layout(merged_adaps_matrix)
+    assert layout.max() == merged_adaps_matrix.data.size(0) - 1
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
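The new tests exercise the autotuned path through ops.scattergather_adamerge_opt; the plain scatter2scatter_sparse wrapper added above takes the same tensors directly. Below is a minimal usage sketch that mirrors the setup in tests/test_bsr_moe.py. The import path of the wrapper itself is an assumption (the diff hunk for the kernel file header is not shown here), so adjust that import to wherever the wrapper actually lives.

```python
import torch

from mttl.models.modifiers.sparsity.sparse_utils import stk_matrix_utils as matrix_ops
from mttl.models.modifiers.sparsity.spb_moe import ops
from mttl.models.modifiers.sparsity.spb_moe.ops import scatter2scatter_sparse  # assumed location

# sizes taken from the smallest SC_MOE_TEST case
bs, d, h, E, k, sparsity, blocking = 8, 128, 256, 10, 2, 0.8, 16
device, dtype = "cuda", torch.float32

X = torch.randn(bs, d, dtype=dtype, device=device)
W = torch.randn(d, h, dtype=dtype, device=device)
base_act = X @ W

# build and stack the per-expert stk block-sparse adapters, exactly as the test does
adaps = [matrix_ops._dense_and_sparse(d, h, sparsity, blocking, dtype)[1] for _ in range(E)]
ada_data = torch.stack([a.data for a in adaps], dim=0)
row_idxs = torch.stack([a.row_indices for a in adaps], dim=0)
col_idxs_t = torch.stack([a.column_indices_t for a in adaps], dim=0)
offsets_t = torch.stack([a.offsets_t for a in adaps], dim=0)
block_offsets_t = torch.stack([a.block_offsets_t for a in adaps], dim=0)

# top-k routing and the index tensors the kernels expect
weights = torch.softmax(torch.randn(bs, E, device=device), dim=-1)
k_weights, expert_idxs = torch.topk(weights, k)
sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(expert_idxs)
padded_block_idxs, _ = ops.padded_block_indices(sorted_expert_idxs, E)

Y = scatter2scatter_sparse(
    X, base_act, ada_data, row_idxs, col_idxs_t, blocking,
    offsets_t, block_offsets_t, sorted_expert_idxs, sorted_scattered_idxs,
    k, padded_block_idxs, k_weights,
)
# Y has shape (bs * k, h): one row per (token, expert) pair, already gated
# and added to the corresponding row of base_act
```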