1+ import torch
2+ import torch .nn as nn
3+ import numpy as np
4+ import time
5+
6+ # Define the missing constants and functions for Marlin Layer
7+ # These would normally come from marlin-specific modules
8+ _perm = torch .randperm (128 ) # Placeholder permutation
9+ _scale_perm = torch .randperm (4 ) # Placeholder scale permutation
10+ _scale_perm_single = torch .randperm (2 ) # Placeholder single scale permutation
11+
12+ def mul (A , B , C , s , workspace ):
13+ """Placeholder implementation of marlin mul function"""
14+ # This is a simplified version - actual implementation would use CUDA kernels
15+ A_flat = A .view (- 1 , A .shape [- 1 ])
16+ C_flat = C .view (- 1 , C .shape [- 1 ])
17+
18+ # Simulated quantized matrix multiplication
19+ # In real implementation, this would dequantize B using s and perform actual GEMM
20+ result = torch .matmul (A_flat .half (), torch .randn (A .shape [- 1 ], C .shape [- 1 ], device = A .device , dtype = torch .half ))
21+ C_flat .copy_ (result )
22+
23+ class Layer (nn .Module ):
24+ """PyTorch compatible Marlin layer; 4-bit (symmetric grouped) linear layer without bias."""
25+
26+ def __init__ (self , infeatures , outfeatures , groupsize = - 1 ):
27+ """Create an empty Marlin layer.
28+ @infeatures: number of input features (must be divisible by 128)
29+ @outfeatures: number of output features (must be divisible by 256)
30+ @groupsize: quantization groupsize (must be -1 or 128)
31+ """
32+ super ().__init__ ()
33+ if groupsize not in [- 1 , 128 ]:
34+ raise ValueError ('Only groupsize -1 and 128 are supported.' )
35+ if infeatures % 128 != 0 or outfeatures % 256 != 0 :
36+ raise ValueError ('`infeatures` must be divisible by 128 and `outfeatures` by 256.' )
37+ if groupsize == - 1 :
38+ groupsize = infeatures
39+ if infeatures % groupsize != 0 :
40+ raise ValueError ('`infeatures` must be divisible by `groupsize`.' )
41+ self .k = infeatures
42+ self .n = outfeatures
43+ self .groupsize = groupsize
44+ self .register_buffer ('B' , torch .empty ((self .k // 16 , self .n * 16 // 8 ), dtype = torch .int ))
45+ self .register_buffer ('s' , torch .empty ((self .k // groupsize , self .n ), dtype = torch .half ))
46+ # 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par`
47+ self .register_buffer ('workspace' , torch .zeros (self .n // 128 * 16 , dtype = torch .int ), persistent = False )
48+
49+ def forward (self , A ):
50+ C = torch .empty (A .shape [:- 1 ] + (self .s .shape [1 ],), dtype = A .dtype , device = A .device )
51+ mul (A .view ((- 1 , A .shape [- 1 ])), self .B , C .view ((- 1 , C .shape [- 1 ])), self .s , self .workspace )
52+ return C
53+
54+ def pack (self , linear , scales ):
55+ """Pack a fake-quantized linear layer into this actual Marlin representation.
56+ @linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`)
57+ @scales: corresponding quantization scales of shape `(infeatures, groups)`
58+ """
59+ if linear .weight .dtype != torch .half :
60+ raise ValueError ('Only `torch.half` weights are supported.' )
61+ tile = 16
62+ maxq = 2 ** 4 - 1
63+ s = scales .t ()
64+ w = linear .weight .data .t ()
65+ if self .groupsize != self .k :
66+ w = w .reshape ((- 1 , self .groupsize , self .n ))
67+ w = w .permute (1 , 0 , 2 )
68+ w = w .reshape ((self .groupsize , - 1 ))
69+ s = s .reshape ((1 , - 1 ))
70+ w = torch .round (w / s ).int ()
71+ w += (maxq + 1 ) // 2
72+ w = torch .clamp (w , 0 , maxq )
73+ if self .groupsize != self .k :
74+ w = w .reshape ((self .groupsize , - 1 , self .n ))
75+ w = w .permute (1 , 0 , 2 )
76+ w = w .reshape ((self .k , self .n )).contiguous ()
77+ s = s .reshape ((- 1 , len (_scale_perm )))[:, _scale_perm ]
78+ else :
79+ s = s .reshape ((- 1 , len (_scale_perm_single )))[:, _scale_perm_single ]
80+ s = s .reshape ((- 1 , self .n )).contiguous ()
81+ w = w .reshape ((self .k // tile , tile , self .n // tile , tile ))
82+ w = w .permute ((0 , 2 , 1 , 3 ))
83+ w = w .reshape ((self .k // tile , self .n * tile ))
84+ res = w
85+ res = res .reshape ((- 1 , _perm .numel ()))[:, _perm ].reshape (res .shape )
86+ q = np .zeros ((res .shape [0 ], res .shape [1 ] // 8 ), dtype = np .uint32 )
87+ res = res .cpu ().numpy ().astype (np .uint32 )
88+ for i in range (8 ):
89+ q |= res [:, i ::8 ] << 4 * i
90+ q = torch .from_numpy (q .astype (np .int32 )).to (w .device )
91+ self .B [:, :] = q .to (self .B .device )
92+ self .s [:, :] = s .to (self .s .device )
93+
94+
95+ def test_marlin_pack_latency ():
96+ """Test the Marlin layer pack function latency"""
97+ print ("Testing Marlin Layer pack function with weight dimensions (1024, 128) and group_size=128" )
98+
99+ # Based on user requirements: weight (1024, 128) means out_features=1024, in_features=128
100+ # After transpose in pack method: (128, 1024) -> infeatures=128, outfeatures=1024
101+ infeatures = 128
102+ outfeatures = 1024
103+ groupsize = 128
104+
105+ # Validate constraints
106+ print (f"infeatures: { infeatures } , outfeatures: { outfeatures } , groupsize: { groupsize } " )
107+ print (f"infeatures % 128 = { infeatures % 128 } " )
108+ print (f"outfeatures % 256 = { outfeatures % 256 } " )
109+ print (f"infeatures % groupsize = { infeatures % groupsize } " )
110+
111+ # Create Marlin layer
112+ marlin_layer = Layer (infeatures = infeatures , outfeatures = outfeatures , groupsize = groupsize )
113+
114+ # Create a fake-quantized linear layer to pack
115+ linear = nn .Linear (in_features = outfeatures , out_features = infeatures , bias = False )
116+ linear .weight .data = torch .randn (infeatures , outfeatures , dtype = torch .half )
117+
118+ # Create random scales with proper shape
119+ # scales shape should be (infeatures, groups) = (128, 1) since groupsize=128=infeatures
120+ num_groups = infeatures // groupsize
121+ scales = torch .randn (infeatures , num_groups , dtype = torch .half ) * 0.1 + 1.0 # scales around 1.0
122+
123+ print (f"Linear layer weight shape: { linear .weight .shape } " )
124+ print (f"Scales shape: { scales .shape } " )
125+
126+ # Move to GPU if available
127+ if torch .cuda .is_available ():
128+ marlin_layer = marlin_layer .cuda ()
129+ linear = linear .cuda ()
130+ scales = scales .cuda ()
131+ print ("Using GPU for testing" )
132+ else :
133+ print ("Using CPU for testing" )
134+
135+ # Test pack function latency
136+ print ("\n Testing pack function latency..." )
137+
138+ # Warm up
139+ print ("Warming up..." )
140+ for _ in range (5 ):
141+ marlin_layer .pack (linear , scales )
142+
143+ # Measure latency
144+ num_runs = 100
145+ print (f"Running { num_runs } iterations..." )
146+
147+ if torch .cuda .is_available ():
148+ torch .cuda .synchronize ()
149+
150+ start_time = time .time ()
151+
152+ for _ in range (num_runs ):
153+ marlin_layer .pack (linear , scales )
154+
155+ if torch .cuda .is_available ():
156+ torch .cuda .synchronize ()
157+
158+ end_time = time .time ()
159+
160+ avg_latency = (end_time - start_time ) / num_runs * 1000 # Convert to milliseconds
161+ total_time = (end_time - start_time ) * 1000 # Convert to milliseconds
162+
163+ print (f"\n Results:" )
164+ print (f"Average pack function latency: { avg_latency :.4f} ms" )
165+ print (f"Total time for { num_runs } runs: { total_time :.2f} ms" )
166+ print (f"Throughput: { num_runs / (total_time / 1000 ):.2f} packs/sec" )
167+
168+
169+ if __name__ == "__main__" :
170+ # Set device
171+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
172+ print (f"Using device: { device } " )
173+
174+ # Set random seed for reproducibility
175+ torch .manual_seed (42 )
176+ np .random .seed (42 )
177+
178+ try :
179+ test_marlin_pack_latency ()
180+ except Exception as e :
181+ print (f"Error during testing: { e } " )
182+ import traceback
183+ traceback .print_exc ()
0 commit comments