Skip to content

Commit 23cad7c

Browse files
committed
update ablation
1 parent 1f1f649 commit 23cad7c

7 files changed

Lines changed: 417 additions & 3 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 python test_bitblas.py
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import bitblas
2+
import torch
3+
import time
4+
import numpy as np
5+
6+
# uncomment to enable debug output
7+
# bitblas.set_log_level("Debug")
8+
9+
# Prefill
10+
n_heads = 1
11+
seq_len = 128
12+
dim = 128
13+
matmul_config = bitblas.MatmulConfig(
14+
M=1, # M dimension
15+
N=n_heads*seq_len, # N dimension
16+
K=dim, # K dimension
17+
A_dtype="float16", # activation A dtype
18+
W_dtype="int4", # weight W dtype
19+
accum_dtype="float16", # accumulation dtype
20+
out_dtype="float16", # output dtype
21+
layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
22+
with_bias=False, # bias
23+
# configs for weight only quantization
24+
group_size=None, # setting for grouped quantization
25+
with_scaling=False, # setting for scaling factor
26+
with_zeros=False, # setting for zeros
27+
zeros_mode=None, # setting for how to calculating zeros
28+
)
29+
30+
matmul = bitblas.Matmul(config=matmul_config)
31+
32+
# Create input matrices
33+
# input_tensor = torch.rand((1, dim), dtype=torch.float16).cuda()
34+
weight_tensor = torch.randint(0, 7, (n_heads*seq_len, dim), dtype=torch.int8).cuda()
35+
36+
# Warmup runs
37+
print("\nWarming up...")
38+
for _ in range(5):
39+
_ = matmul.transform_weight(weight_tensor)
40+
torch.cuda.synchronize()
41+
42+
# Timing runs
43+
num_runs = 10
44+
times = []
45+
46+
print(f"\nRunning {num_runs} timing iterations...")
47+
48+
for i in range(num_runs):
49+
torch.cuda.synchronize()
50+
start_time = time.perf_counter()
51+
52+
weight_tensor_int4 = matmul.transform_weight(weight_tensor)
53+
54+
torch.cuda.synchronize()
55+
end_time = time.perf_counter()
56+
57+
elapsed_time = (end_time - start_time) * 1000 # Convert to milliseconds
58+
times.append(elapsed_time)
59+
60+
if (i + 1) % 20 == 0:
61+
print(f" Completed {i + 1}/{num_runs} runs")
62+
63+
times = np.array(times)
64+
mean_time = np.mean(times)
65+
66+
print(f"Mean time: {mean_time} ms")

evaluation/ablation/test_marlin.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import torch
2+
import torch.nn as nn
3+
import numpy as np
4+
import time
5+
6+
# Define the missing constants and functions for Marlin Layer
7+
# These would normally come from marlin-specific modules
8+
_perm = torch.randperm(128) # Placeholder permutation
9+
_scale_perm = torch.randperm(4) # Placeholder scale permutation
10+
_scale_perm_single = torch.randperm(2) # Placeholder single scale permutation
11+
12+
def mul(A, B, C, s, workspace):
13+
"""Placeholder implementation of marlin mul function"""
14+
# This is a simplified version - actual implementation would use CUDA kernels
15+
A_flat = A.view(-1, A.shape[-1])
16+
C_flat = C.view(-1, C.shape[-1])
17+
18+
# Simulated quantized matrix multiplication
19+
# In real implementation, this would dequantize B using s and perform actual GEMM
20+
result = torch.matmul(A_flat.half(), torch.randn(A.shape[-1], C.shape[-1], device=A.device, dtype=torch.half))
21+
C_flat.copy_(result)
22+
23+
class Layer(nn.Module):
24+
"""PyTorch compatible Marlin layer; 4-bit (symmetric grouped) linear layer without bias."""
25+
26+
def __init__(self, infeatures, outfeatures, groupsize=-1):
27+
"""Create an empty Marlin layer.
28+
@infeatures: number of input features (must be divisible by 128)
29+
@outfeatures: number of output features (must be divisible by 256)
30+
@groupsize: quantization groupsize (must be -1 or 128)
31+
"""
32+
super().__init__()
33+
if groupsize not in [-1, 128]:
34+
raise ValueError('Only groupsize -1 and 128 are supported.')
35+
if infeatures % 128 != 0 or outfeatures % 256 != 0:
36+
raise ValueError('`infeatures` must be divisible by 128 and `outfeatures` by 256.')
37+
if groupsize == -1:
38+
groupsize = infeatures
39+
if infeatures % groupsize != 0:
40+
raise ValueError('`infeatures` must be divisible by `groupsize`.')
41+
self.k = infeatures
42+
self.n = outfeatures
43+
self.groupsize = groupsize
44+
self.register_buffer('B', torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int))
45+
self.register_buffer('s', torch.empty((self.k // groupsize, self.n), dtype=torch.half))
46+
# 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par`
47+
self.register_buffer('workspace', torch.zeros(self.n // 128 * 16, dtype=torch.int), persistent=False)
48+
49+
def forward(self, A):
50+
C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
51+
mul(A.view((-1, A.shape[-1])), self.B, C.view((-1, C.shape[-1])), self.s, self.workspace)
52+
return C
53+
54+
def pack(self, linear, scales):
55+
"""Pack a fake-quantized linear layer into this actual Marlin representation.
56+
@linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`)
57+
@scales: corresponding quantization scales of shape `(infeatures, groups)`
58+
"""
59+
if linear.weight.dtype != torch.half:
60+
raise ValueError('Only `torch.half` weights are supported.')
61+
tile = 16
62+
maxq = 2 ** 4 - 1
63+
s = scales.t()
64+
w = linear.weight.data.t()
65+
if self.groupsize != self.k:
66+
w = w.reshape((-1, self.groupsize, self.n))
67+
w = w.permute(1, 0, 2)
68+
w = w.reshape((self.groupsize, -1))
69+
s = s.reshape((1, -1))
70+
w = torch.round(w / s).int()
71+
w += (maxq + 1) // 2
72+
w = torch.clamp(w, 0, maxq)
73+
if self.groupsize != self.k:
74+
w = w.reshape((self.groupsize, -1, self.n))
75+
w = w.permute(1, 0, 2)
76+
w = w.reshape((self.k, self.n)).contiguous()
77+
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
78+
else:
79+
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
80+
s = s.reshape((-1, self.n)).contiguous()
81+
w = w.reshape((self.k // tile, tile, self.n // tile, tile))
82+
w = w.permute((0, 2, 1, 3))
83+
w = w.reshape((self.k // tile, self.n * tile))
84+
res = w
85+
res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
86+
q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
87+
res = res.cpu().numpy().astype(np.uint32)
88+
for i in range(8):
89+
q |= res[:, i::8] << 4 * i
90+
q = torch.from_numpy(q.astype(np.int32)).to(w.device)
91+
self.B[:, :] = q.to(self.B.device)
92+
self.s[:, :] = s.to(self.s.device)
93+
94+
95+
def test_marlin_pack_latency():
96+
"""Test the Marlin layer pack function latency"""
97+
print("Testing Marlin Layer pack function with weight dimensions (1024, 128) and group_size=128")
98+
99+
# Based on user requirements: weight (1024, 128) means out_features=1024, in_features=128
100+
# After transpose in pack method: (128, 1024) -> infeatures=128, outfeatures=1024
101+
infeatures = 128
102+
outfeatures = 1024
103+
groupsize = 128
104+
105+
# Validate constraints
106+
print(f"infeatures: {infeatures}, outfeatures: {outfeatures}, groupsize: {groupsize}")
107+
print(f"infeatures % 128 = {infeatures % 128}")
108+
print(f"outfeatures % 256 = {outfeatures % 256}")
109+
print(f"infeatures % groupsize = {infeatures % groupsize}")
110+
111+
# Create Marlin layer
112+
marlin_layer = Layer(infeatures=infeatures, outfeatures=outfeatures, groupsize=groupsize)
113+
114+
# Create a fake-quantized linear layer to pack
115+
linear = nn.Linear(in_features=outfeatures, out_features=infeatures, bias=False)
116+
linear.weight.data = torch.randn(infeatures, outfeatures, dtype=torch.half)
117+
118+
# Create random scales with proper shape
119+
# scales shape should be (infeatures, groups) = (128, 1) since groupsize=128=infeatures
120+
num_groups = infeatures // groupsize
121+
scales = torch.randn(infeatures, num_groups, dtype=torch.half) * 0.1 + 1.0 # scales around 1.0
122+
123+
print(f"Linear layer weight shape: {linear.weight.shape}")
124+
print(f"Scales shape: {scales.shape}")
125+
126+
# Move to GPU if available
127+
if torch.cuda.is_available():
128+
marlin_layer = marlin_layer.cuda()
129+
linear = linear.cuda()
130+
scales = scales.cuda()
131+
print("Using GPU for testing")
132+
else:
133+
print("Using CPU for testing")
134+
135+
# Test pack function latency
136+
print("\nTesting pack function latency...")
137+
138+
# Warm up
139+
print("Warming up...")
140+
for _ in range(5):
141+
marlin_layer.pack(linear, scales)
142+
143+
# Measure latency
144+
num_runs = 100
145+
print(f"Running {num_runs} iterations...")
146+
147+
if torch.cuda.is_available():
148+
torch.cuda.synchronize()
149+
150+
start_time = time.time()
151+
152+
for _ in range(num_runs):
153+
marlin_layer.pack(linear, scales)
154+
155+
if torch.cuda.is_available():
156+
torch.cuda.synchronize()
157+
158+
end_time = time.time()
159+
160+
avg_latency = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds
161+
total_time = (end_time - start_time) * 1000 # Convert to milliseconds
162+
163+
print(f"\nResults:")
164+
print(f"Average pack function latency: {avg_latency:.4f} ms")
165+
print(f"Total time for {num_runs} runs: {total_time:.2f} ms")
166+
print(f"Throughput: {num_runs / (total_time / 1000):.2f} packs/sec")
167+
168+
169+
if __name__ == "__main__":
170+
# Set device
171+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
172+
print(f"Using device: {device}")
173+
174+
# Set random seed for reproducibility
175+
torch.manual_seed(42)
176+
np.random.seed(42)
177+
178+
try:
179+
test_marlin_pack_latency()
180+
except Exception as e:
181+
print(f"Error during testing: {e}")
182+
import traceback
183+
traceback.print_exc()

evaluation/bench_throughput.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import argparse
2+
import dataclasses
3+
import time
4+
import numpy as np
5+
import torch
6+
from tqdm.auto import tqdm
7+
from llama import LlamaForCausalLM
8+
from transformers import LlamaConfig, AutoTokenizer
9+
10+
@dataclasses.dataclass
11+
class ModelConfig:
12+
model_path: str
13+
dtype: str = dataclasses.field(default="float16")
14+
# device: str = dataclasses.field(default="cuda:0")
15+
16+
17+
def load_model(args):
18+
# device = torch.device(args.device)
19+
dtype = getattr(torch, args.dtype)
20+
torch.set_default_dtype(dtype)
21+
22+
config = LlamaConfig.from_pretrained(args.model_path)
23+
config.attn_backend = args.attn_backend
24+
config.num_bits = args.num_bits
25+
config.quant_mode = args.quant_mode
26+
config.group_size = args.group_size
27+
config.residual_block_size = 128 if args.num_bits == 4 else 256
28+
29+
model = LlamaForCausalLM.from_pretrained(
30+
args.model_path,
31+
config=config,
32+
device_map="auto",
33+
torch_dtype=dtype
34+
)
35+
return model
36+
37+
@torch.inference_mode()
38+
def benchmark_throughput():
39+
parser = argparse.ArgumentParser()
40+
parser.add_argument("--model_path", default="llama3-8b-instruct")
41+
parser.add_argument("--batch_size", type=int, default=1)
42+
parser.add_argument("--context_len", type=int, default=2*1024)
43+
parser.add_argument("--decode_len", type=int, default=256)
44+
parser.add_argument("--iteration", type=int, default=10)
45+
parser.add_argument("--device", type=str, default="cuda:0")
46+
parser.add_argument("--dtype", type=str, default="float16")
47+
parser.add_argument("--attn_backend", type=str, default="flash_attention_2")
48+
parser.add_argument("--num_bits", type=int, default=4)
49+
parser.add_argument("--quant_mode", type=str, default="k-channel")
50+
parser.add_argument("--group_size", type=int, default=128)
51+
52+
args = parser.parse_args()
53+
54+
model = load_model(args)
55+
56+
context_len = args.context_len
57+
decode_len = args.decode_len
58+
batch_size = args.batch_size
59+
60+
dtype = getattr(torch, args.dtype)
61+
device = torch.device(args.device)
62+
hidden_size = model.config.hidden_size
63+
64+
prefill_latency = []
65+
decode_latency = []
66+
67+
for iter_idx in tqdm(range(args.iteration)):
68+
# clear cuda cache
69+
torch.cuda.empty_cache()
70+
torch.cuda.reset_peak_memory_stats(device)
71+
72+
# Prefill Stage
73+
ts = time.perf_counter()
74+
hidden_states = torch.randn(batch_size, context_len, hidden_size, dtype=dtype, device=device)
75+
out = model(
76+
inputs_embeds=hidden_states,
77+
use_cache=True
78+
)
79+
torch.cuda.synchronize()
80+
te = time.perf_counter()
81+
prefill_latency.append(te - ts)
82+
83+
# Memory stats after prefill
84+
if iter_idx == 0:
85+
print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(device) / 1e6:.2f} MB")
86+
print(f"Peak GPU Memory: {torch.cuda.max_memory_allocated(device) / 1e6:.2f} MB")
87+
88+
# Warm up for decode
89+
for _ in range(5):
90+
hidden_states = torch.randn(batch_size, 1, hidden_size, dtype=dtype, device=device)
91+
model(
92+
inputs_embeds=hidden_states,
93+
past_key_values=out.past_key_values,
94+
use_cache=True,
95+
)
96+
97+
# Decode Stage - measure total time for all tokens
98+
ts_decode_total = time.perf_counter()
99+
for _ in range(decode_len):
100+
hidden_states = torch.randn(batch_size, 1, hidden_size, dtype=dtype, device=device)
101+
out = model(
102+
inputs_embeds=hidden_states,
103+
past_key_values=out.past_key_values,
104+
use_cache=True,
105+
)
106+
torch.cuda.synchronize()
107+
te_decode_total = time.perf_counter()
108+
decode_latency.append(te_decode_total - ts_decode_total)
109+
110+
# Calculate metrics
111+
avg_prefill_latency = np.mean(prefill_latency)
112+
avg_decode_latency = np.mean(decode_latency)
113+
# avg_decode_latency -= 0.0019366741180 * 32
114+
115+
# Calculate throughput
116+
prefill_throughput = (batch_size * context_len) / avg_prefill_latency
117+
decode_throughput = (batch_size * decode_len) / avg_decode_latency
118+
119+
# Print results in a table format
120+
print("\n===== BENCHMARK RESULTS =====")
121+
print(f"Model: {args.model_path}")
122+
print(f"Batch Size: {batch_size}")
123+
print(f"Context Length: {context_len}")
124+
print(f"Decode Length: {decode_len}")
125+
print(f"Quantization: {args.num_bits}-bit {args.quant_mode}")
126+
print("\n--- Latency ---")
127+
print(f"Avg Prefill Latency: {avg_prefill_latency:.4f} s")
128+
print(f"Avg Decode Latency (total): {avg_decode_latency:.4f} s")
129+
print(f"Avg Decode Latency (per token): {avg_decode_latency/decode_len:.4f} s")
130+
print("\n--- Throughput ---")
131+
print(f"Prefill Throughput: {prefill_throughput:.2f} tokens/s")
132+
print(f"Decode Throughput: {decode_throughput:.2f} tokens/s")
133+
134+
# CSV format for easy parsing
135+
print("\n--- CSV Format ---")
136+
print("batch_size,context_len,decode_len,prefill_latency,decode_latency,prefill_throughput,decode_throughput")
137+
print(f"{batch_size},{context_len},{decode_len},{avg_prefill_latency:.4f},{avg_decode_latency:.4f},{prefill_throughput:.2f},{decode_throughput:.2f}")
138+
139+
if __name__ == "__main__":
140+
benchmark_throughput()

0 commit comments

Comments
 (0)