 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, cast, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from aiter import QuantType
-from aiter.utility.dtypes import d_dtypes
+from atom.quant_spec import (
+    LayerQuantConfig,
+    ParsedQuantConfig,
+    get_quant_parser,
+)
 from atom.utils import envs, get_open_port
 from atom.utils.distributed.utils import stateless_init_torch_distributed_process_group
 from torch.distributed import ProcessGroup, ReduceOp
@@ -251,70 +255,94 @@ def set_splitting_ops_for_v1(self):
         ]
 
 
-class LayerQuantConfig(dict):
-    def __init__(
-        self,
-        quant_type=QuantType.No,
-        quant_dtype=torch.bfloat16,
-        is_dynamic=True,
-        quant_method="",
-    ):
-        """
-        Core components of layer_quant
-        """
-        super().__init__()
-        self["quant_type"] = quant_type if quant_type is not None else QuantType.No
-        self["quant_dtype"] = quant_dtype if quant_dtype is not None else torch.bfloat16
-        self["is_dynamic"] = is_dynamic
-        self["quant_method"] = quant_method
+class QuantizationConfig:
+    """Model-wide quantization configuration.
 
+    API:
+    - ``get_layer_quant_config(layer_name)`` -> :class:`LayerQuantConfig`
+    - ``global_quant_config`` property -> :class:`LayerQuantConfig`
+    - ``quant_type``, ``quant_dtype``, ``is_dynamic`` convenience properties
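+
+    Illustrative usage (a sketch only; ``hf_config`` and the layer name below
+    are hypothetical, not taken from this change)::
+
+        qcfg = QuantizationConfig(hf_config)
+        spec = qcfg.get_layer_quant_config("model.layers.0.self_attn.qkv_proj")
+        if spec.quant_type != QuantType.No:
+            ...  # build the quantized kernel for this layer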
265+ """
271266
272- class QuantizationConfig :
273267 def __init__ (self , config : PretrainedConfig = None ):
274268 if config is None :
275269 self .torch_dtype = torch .bfloat16
276270 self .hf_quant_config = None
277- self .global_quant_config = LayerQuantConfig ()
278- self .layer_quant_config = {}
279- self .exclude_layers = []
271+ self ._parsed = ParsedQuantConfig ()
272+ self .exclude_layers : list [str ] = []
280273 self .quant_method = ""
281274 return
282275
283276 self .torch_dtype = getattr (config , "torch_dtype" , torch .bfloat16 )
284277 self .hf_quant_config = getattr (config , "quantization_config" , None )
285- self .global_quant_config = None
286- self .layer_quant_config = {}
287278 self .exclude_layers = []
288279
289280 if self .hf_quant_config is None :
290- self .global_quant_config = LayerQuantConfig (
291- quant_type = QuantType .No , quant_dtype = self .torch_dtype
281+ self ._parsed = ParsedQuantConfig (
282+ global_spec = LayerQuantConfig (
283+ quant_type = QuantType .No , quant_dtype = self .torch_dtype
284+ )
292285 )
293286 self .quant_method = ""
294287 return
295288
296289 self .quant_method = self .hf_quant_config .get ("quant_method" , "" )
297- if self .quant_method == "quark" :
298- layer_quant_config_dict = cast (
299- dict [str , Any ], self .hf_quant_config .get ("layer_quant_config" , {})
300- )
301- for layer_name , layer_cfg in layer_quant_config_dict .items ():
302- self .layer_quant_config [layer_name ] = self .parse_quark_config_dict (
303- layer_cfg
304- )
305290
306- global_quant_config_dict = cast (
307- dict [str , Any ], self .hf_quant_config .get ("global_quant_config" , {})
308- )
309- self .global_quant_config = self .parse_quark_config_dict (
310- global_quant_config_dict
311- )
291+ # Use the parser registry to build a structured ParsedQuantConfig
292+ parser = get_quant_parser (self .quant_method )
293+ self ._parsed = parser .parse (self .hf_quant_config )
294+ self .exclude_layers = list (self ._parsed .exclude_layers )
312295
313- self .exclude_layers = cast (
314- list [str ], self .hf_quant_config .get ("exclude" , [])
315- )
316- else :
317- self .parse_other_config ()
296+ # -- typed API (preferred) ----------------------------------------------
297+
298+ @property
299+ def global_quant_config (self ) -> LayerQuantConfig :
300+ """The default quantization spec for all layers."""
301+ return self ._parsed .global_spec
302+
303+ def get_layer_quant_config (self , layer_name : str ) -> LayerQuantConfig :
304+ """Return the :class:`LayerQuantConfig` for *layer_name*.
305+
+        Resolution order:
+        1. Check the exclude list -> unquantized ``LayerQuantConfig``.
+        2. Exact match in ``parsed.layer_specs``.
+        3. fnmatch-style pattern match in ``parsed.layer_pattern_specs``.
+        4. Fall back to ``global_quant_config``.
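+
+        Illustrative sketch (``cfg`` and the layer names are hypothetical;
+        assumes ``lm_head`` appears in ``exclude_layers``)::
+
+            cfg.get_layer_quant_config("lm_head")               # unquantized spec
+            cfg.get_layer_quant_config("model.layers.3.mlp.up") # pattern or global spec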
311+ """
312+ # 1. Exclude list
313+ if self ._is_excluded (layer_name ):
314+ return LayerQuantConfig (quant_dtype = self .torch_dtype )
315+
316+ # 2. Exact match
317+ if layer_name in self ._parsed .layer_specs :
318+ return self ._parsed .layer_specs [layer_name ]
319+
320+ # 3. Pattern match
321+ for pattern , spec in self ._parsed .layer_pattern_specs :
322+ if "*" not in pattern :
+                if layer_name in pattern:
+                    return spec
+            elif fnmatch.fnmatch(layer_name, pattern):
+                return spec
+
+        # 4. Global default
+        return self._parsed.global_spec
+
+    # -- convenience properties (delegate to global_quant_config) -------------
+
+    @property
+    def quant_type(self) -> QuantType:
+        return self._parsed.global_spec.quant_type
+
+    @property
+    def quant_dtype(self) -> torch.dtype:
+        return self._parsed.global_spec.quant_dtype
+
+    @property
+    def is_dynamic(self) -> bool:
+        return self._parsed.global_spec.is_dynamic
+
+    # -- other methods ---------------------------------------------------------
 
     def compute_hash(self) -> str:
         """
@@ -329,191 +357,53 @@ def compute_hash(self) -> str:
         the final hidden states.
         """
         factors: list[Any] = []
-        factors.append(self.global_quant_config)
-        factors.append(self.layer_quant_config)
+        factors.append(self._parsed.global_spec)
+        factors.append(self._parsed.layer_pattern_specs)
         factors.append(self.exclude_layers)
         hash_value = hashlib.sha256(str(factors).encode()).hexdigest()
         return hash_value
 
     def get_name(self):
-        """
-        Returns the quantization method name.
-        """
+        """Returns the quantization method name."""
         return self.quant_method
 
-    def parse_quark_config_dict(self, config: dict) -> LayerQuantConfig:
-        quant_type = None
-        quant_dtype = None
-        is_dynamic = True
-        weight_config = cast(dict[str, Any], config.get("weight", {}))
-        input_config = cast(dict[str, Any], config.get("input_tensors", {}))
-        weight_qscheme = cast(str, weight_config.get("qscheme", ""))
-        weight_dtype = weight_config.get("dtype", "")
-
-        # quant_type
-        if weight_qscheme == "per_channel":
-            quant_type = QuantType.per_Token
-        elif weight_qscheme == "per_tensor":
-            quant_type = QuantType.per_Tensor
-        elif weight_qscheme == "per_group":
-            # Currently, quark only supports group_size=32
-            quant_type = QuantType.per_1x32
-        else:
-            quant_type = QuantType.No
-
-        # quant_dtype
-        dtype = weight_dtype.split("_")[0]
-        if dtype.endswith("4"):
-            dtype += "x2"
-        quant_dtype = d_dtypes[dtype]
-
-        # is_dynamic
-        if input_config:
-            # input_dtype = input_config.get("dtype")
-            # input_qscheme = cast(str, input_config.get("qscheme"))
-            is_dynamic = cast(bool, input_config.get("is_dynamic", True))
-        return LayerQuantConfig(
-            quant_type=quant_type,
-            quant_dtype=quant_dtype,
-            is_dynamic=is_dynamic,
-            quant_method="quark",
-        )
-
-    # TODO: For now, it's just a temporary migration.
-    # We should subsequently refine them in a targeted manner.
-    def parse_other_config(self):
-        RE_QUANT_BLOCKSIZE = (
-            r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+),"
-        )
-        orig_quant_config = self.hf_quant_config
-        quant_method = self.quant_method
-        orig_quant_config_str = str(orig_quant_config)
-        if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str:
-            quant_type = QuantType.per_Token
-        elif group_size := re.search(RE_QUANT_BLOCKSIZE, orig_quant_config_str):
-            group_size = int(group_size.group(1))
-            assert group_size in (32, 128), f"Unsupported group size {group_size}"
-            if group_size == 128:
-                quant_type = QuantType.per_1x128
-            elif group_size == 32:
-                quant_type = QuantType.per_1x32
-        else:
-            quant_type = QuantType.per_Tensor
-
-        RE_QUANT_DTYPE = r"\'(?:d?type|weight_dtype|quant_method)\'\:\s*\'(\w+)\'"
-        quant_dtype = None
-        m = re.search(RE_QUANT_DTYPE, orig_quant_config_str)
-        if m and m.group(1).lower() in [
-            "fp8",
-            "fp4",
-            "int8",
-            "int4",
-            "fp8_e4m3",
-            "mxfp4",
-        ]:
-            dtype = m.group(1).lower().split("_")[0]
-            if dtype == "mxfp4":
-                dtype = "fp4"
-            if dtype.endswith("4"):
-                dtype += "x2"
-            quant_dtype = d_dtypes[dtype]
-        else:
-            bit_match = re.search(r"\'(?:num_)?bits\'\:\s*(\d+)", orig_quant_config_str)
-            if bit_match:
-                bit = int(bit_match.group(1))
-                dtype_match = re.search(RE_QUANT_DTYPE, orig_quant_config_str)
-                if dtype_match:
-                    dtype = dtype_match.group(1).lower()
-                    dtype_prefix = "i" if dtype.startswith("int") else "fp"
-                else:
-                    dtype_prefix = "i"
-                quant_dtype_str = (
-                    f"{dtype_prefix}{bit}" if bit != 4 else f"{dtype_prefix}{bit}x2"
-                )
-                quant_dtype = d_dtypes.get(quant_dtype_str, None)
-        assert (
-            quant_dtype is not None
-        ), f"Cannot parse quant dtype from {orig_quant_config_str}"
-        if quant_dtype == d_dtypes["fp4x2"]:
-            quant_type = QuantType.per_1x32
-
-        RE_STATIC_QUANT = r"\'(?:activation_scheme)\'\:\s*\'(static)\'"
-        if re.search(RE_STATIC_QUANT, orig_quant_config_str):
-            is_dynamic = False
-        else:
-            is_dynamic = True
-        if quant_method == "compressed-tensors":
-            exclude_layers_key = "ignore"
-        else:
-            logger.warning(
-                f"Using 'ignore' as key for exclude layers with quant_method "
-                f"{quant_method}, please double check the quantization config."
-            )
-            exclude_layers_key = "ignore"
-        exclude_layers = orig_quant_config.get(exclude_layers_key, [])
-
-        self.global_quant_config = LayerQuantConfig(
-            quant_type=quant_type,
-            quant_dtype=quant_dtype,
-            is_dynamic=is_dynamic,
-            quant_method=quant_method,
-        )
-        self.exclude_layers = exclude_layers
+    # -- internal helpers ------------------------------------------------------
 
-    def should_ignore_layer_quant(self, layer_name: str) -> bool:
-        # TODO: solve fused_mapping case
+    def _is_excluded(self, layer_name: str) -> bool:
         if layer_name is None or not self.exclude_layers:
             return False
         return any(
-            self.is_equal_or_regex_match(layer_name, ignore_str)
+            self._matches_exclude(layer_name, ignore_str)
            for ignore_str in self.exclude_layers
         )
 
-    def is_equal_or_regex_match(
-        self, layer_name: str, ignore_str: str, check_contains: bool = False
+    @staticmethod
+    def _matches_exclude(
+        layer_name: str, ignore_str: str, check_contains: bool = False
     ) -> bool:
-        """Match the target string or regular expression"""
+        """Match the target string or regular expression.
+
+        Supports exact match, prefix match (layer under an excluded module),
+        fnmatch glob patterns (``*`` / ``?``), and ``re:`` regex patterns.
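+
+        Illustrative sketch (hypothetical layer names)::
+
+            _matches_exclude("lm_head.weight", "lm_head")          # prefix -> True
+            _matches_exclude("model.layers.0.mlp", "model.*.mlp")  # glob   -> True
+            _matches_exclude("model.embed_tokens", "re:embed")     # regex  -> True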
388+ """
476389 if ignore_str .startswith ("re:" ):
477- # case "re:model.layers.*self_attn.*", remove the 're:' prefix
478390 pattern = ignore_str [3 :]
479391 if re .search (pattern , layer_name ):
480392 return True
481- # case exclude_layer like "model.layers.0.self_attn.q_a_proj" (dpsk-attn)
482- # a common prefix for linear layers in attn like "model.layers.0.self_attn"
393+ elif "*" in ignore_str or "?" in ignore_str :
394+ # Glob pattern: match exact or as prefix of deeper sub-modules
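+            # (e.g. "model.*.mlp" also covers "model.layers.0.mlp.gate_proj")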
+            if fnmatch.fnmatch(layer_name, ignore_str):
+                return True
+            if fnmatch.fnmatch(layer_name, ignore_str + ".*"):
+                return True
         elif check_contains:
             return layer_name.lower() in ignore_str.lower()
-        elif ignore_str == layer_name:
-            return True
+        else:
+            # Exact match or prefix match (e.g. "lm_head" excludes "lm_head.weight")
+            if layer_name == ignore_str or layer_name.startswith(ignore_str + "."):
+                return True
         return False
 
-    def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig:
-        if self.should_ignore_layer_quant(layer_name=layer_name):
-            # return unquantized config
-            return LayerQuantConfig(quant_dtype=self.torch_dtype)
-        # layer quant config
-        layer_quant_config = None
-        if self.layer_quant_config:
-
-            def _matches_pattern(layer_name, pattern):
-                if "*" not in pattern:
-                    return layer_name in pattern
-                return fnmatch.fnmatch(layer_name, pattern)
-
-            for name_pattern, config in self.layer_quant_config.items():
-                if _matches_pattern(layer_name, name_pattern):
-                    layer_quant_config = config
-
-        layer_quant_config = (
-            self.global_quant_config
-            if layer_quant_config is None
-            else layer_quant_config
-        )
-        # TODO: if use_aiter, we can customize the quantization format here, such as dpsk
-        # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs,
-        # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs
-
-        return layer_quant_config
-
     def remap_layer_name(
         self, hf_config: PretrainedConfig, packed_modules_mapping: dict | None = None
     ):
@@ -556,11 +446,11 @@ def _remap_layer_name(name: str) -> list[str]:
                 return [name.replace(packed_key, packed_remap_part, 1)]
             return [name]
 
-        new_layer_quant_config = {}
-        for layer_name, layer_qconfig in self.layer_quant_config.items():
-            for remapped in _remap_layer_name(layer_name):
-                new_layer_quant_config[remapped] = layer_qconfig
-        self.layer_quant_config = new_layer_quant_config
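+        # Re-key the per-layer pattern specs with their remapped (packed/fused)
+        # names so the new module names resolve to the same specs.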
+        new_pattern_specs = []
+        for pattern, spec in self._parsed.layer_pattern_specs:
+            for remapped in _remap_layer_name(pattern):
+                new_pattern_specs.append((remapped, spec))
+        self._parsed.layer_pattern_specs = new_pattern_specs
 
         new_exclude = []
         for name in self.exclude_layers: