@@ -30,9 +30,6 @@ def get_marlin_layer(): ##use an ugly wrapper to import gptqmodel on demand
3030 NEW_VERSION = False
3131 if Version (gptqmodel .__version__ ) >= Version ("5.0.0" ):
3232 NEW_VERSION = True
33- NEW_VERSION_6_0 = False
34- if Version (gptqmodel .__version__ ) >= Version ("6.0.0" ):
35- NEW_VERSION_6_0 = True
3633 from gptqmodel .models ._const import DEVICE , PLATFORM # pylint: disable=E0401
3734 from gptqmodel .nn_modules .qlinear import BaseQuantLinear # pylint: disable=E0401
3835 from gptqmodel .utils .backend import BACKEND # pylint: disable=E0401
@@ -247,59 +244,20 @@ def __init__(
247244 # (since we have only one group per output channel)
248245 desc_act = False
249246
250- backend = kwargs .pop ("backend" , BACKEND .MARLIN )
251- if NEW_VERSION_6_0 :
252- # gptqmodel >= 6.0.0: BaseQuantLinear no longer accepts group_size/sym/desc_act/pack_dtype
253- # directly; they must be passed via validate_kwargs. Attributes are also set manually.
254- super ().__init__ (
255- bits = bits ,
256- in_features = in_features ,
257- out_features = out_features ,
258- bias = bias ,
259- backend = backend ,
260- adapter = None ,
261- register_buffers = False ,
262- validate_kwargs = {
263- "group_size" : group_size ,
264- "desc_act" : desc_act ,
265- "sym" : sym ,
266- "pack_dtype" : pack_dtype ,
267- },
268- ** kwargs ,
269- )
270- # Set attributes that intermediate classes (PackedQuantLinear /
271- # GPTQQuantLinear) would have set in the old API.
272- self .pack_dtype = pack_dtype
273- if pack_dtype == torch .int8 :
274- self .pack_dtype_bits = 8
275- elif pack_dtype == torch .int16 :
276- self .pack_dtype_bits = 16
277- elif pack_dtype == torch .int32 :
278- self .pack_dtype_bits = 32
279- elif pack_dtype == torch .int64 :
280- self .pack_dtype_bits = 64
281- else :
282- raise ValueError (f"Unsupported pack_dtype: { pack_dtype } " )
283- self .pack_factor = self .pack_dtype_bits // bits
284- self .group_size = group_size if group_size != - 1 else in_features
285- self .requested_group_size = group_size
286- self .desc_act = desc_act
287- self .sym = sym
288- else :
289- super ().__init__ (
290- bits = bits ,
291- group_size = group_size ,
292- sym = sym ,
293- desc_act = desc_act ,
294- in_features = in_features ,
295- out_features = out_features ,
296- bias = bias ,
297- pack_dtype = pack_dtype ,
298- backend = backend ,
299- adapter = None ,
300- register_buffers = False ,
301- ** kwargs ,
302- )
247+ super ().__init__ (
248+ bits = bits ,
249+ group_size = group_size ,
250+ sym = sym ,
251+ desc_act = desc_act ,
252+ in_features = in_features ,
253+ out_features = out_features ,
254+ bias = bias ,
255+ pack_dtype = pack_dtype ,
256+ backend = kwargs .pop ("backend" , BACKEND .MARLIN ),
257+ adapter = None ,
258+ register_buffers = False ,
259+ ** kwargs ,
260+ )
303261
304262 # toggle fp32 mode depending on MARLIN or MARLIN_FP16 backend
305263 self .fp32 = True if self .backend in [BACKEND .MARLIN , BACKEND .AUTO ] else False
0 commit comments