 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, cast, Optional, Union
+from typing import Any, Optional, Union

 import torch
 from aiter import QuantType
-from aiter.utility.dtypes import d_dtypes
+from atom.quant_spec import (
+    LayerQuantConfig,
+    ParsedQuantConfig,
+    get_quant_parser,
+)
 from atom.utils import envs, get_open_port
 from atom.utils.distributed.utils import stateless_init_torch_distributed_process_group
 from torch.distributed import ProcessGroup, ReduceOp
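The old dict-based `LayerQuantConfig` and regex-driven parsing below are replaced by types imported from a new `atom.quant_spec` module that is not included in this diff. As a rough orientation only, here is a hypothetical sketch of what those imports could look like, inferred purely from how they are used later in the diff; the field names, defaults, and the `no_quant()` helper are assumptions, not the actual module:

```python
# Hypothetical sketch of atom/quant_spec.py -- this module is NOT part of the diff.
from dataclasses import dataclass, field

import torch
from aiter import QuantType


@dataclass
class LayerQuantConfig:
    """Per-layer quantization spec (successor of the old dict subclass)."""

    quant_type: QuantType = QuantType.No
    quant_dtype: torch.dtype = torch.bfloat16
    is_dynamic: bool = True
    quant_method: str = ""

    @classmethod
    def no_quant(cls, dtype: torch.dtype = torch.bfloat16) -> "LayerQuantConfig":
        # Unquantized layer: QuantType.No at the model's activation dtype.
        return cls(quant_type=QuantType.No, quant_dtype=dtype)


@dataclass
class ParsedQuantConfig:
    """Structured result a quant-config parser returns for one HF checkpoint."""

    global_spec: LayerQuantConfig
    layer_pattern_specs: list[tuple[str, LayerQuantConfig]] = field(default_factory=list)
    exclude_layers: list[str] = field(default_factory=list)


def get_quant_parser(quant_method: str):
    """Look up a parser (quark, compressed-tensors, ...) by HF quant_method."""
    raise NotImplementedError  # the real registry lives in atom.quant_spec
```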
@@ -251,70 +255,92 @@ def set_splitting_ops_for_v1(self):
         ]


-class LayerQuantConfig(dict):
-    def __init__(
-        self,
-        quant_type=QuantType.No,
-        quant_dtype=torch.bfloat16,
-        is_dynamic=True,
-        quant_method="",
-    ):
-        """
-        Core components of layer_quant
-        """
-        super().__init__()
-        self["quant_type"] = quant_type if quant_type is not None else QuantType.No
-        self["quant_dtype"] = quant_dtype if quant_dtype is not None else torch.bfloat16
-        self["is_dynamic"] = is_dynamic
-        self["quant_method"] = quant_method
+class QuantizationConfig:
+    """Model-wide quantization configuration.

+    API:
+    - ``get_layer_quant_config(prefix)`` -> :class:`LayerQuantConfig`
+    - ``global_quant_config`` property -> :class:`LayerQuantConfig`
+    - ``quant_type``, ``quant_dtype``, ``is_dynamic`` convenience properties
+    """

-class QuantizationConfig:
     def __init__(self, config: PretrainedConfig = None):
         if config is None:
             self.torch_dtype = torch.bfloat16
             self.hf_quant_config = None
-            self.global_quant_config = LayerQuantConfig()
-            self.layer_quant_config = {}
-            self.exclude_layers = []
+            self.global_spec: LayerQuantConfig = LayerQuantConfig()
+            self.layer_pattern_specs: list[tuple[str, LayerQuantConfig]] = []
+            self.exclude_layers: list[str] = []
             self.quant_method = ""
             return

-        self.torch_dtype = getattr(config, "torch_dtype", torch.bfloat16)
+        # Some HF configs set torch_dtype=None; normalize to bf16 default.
+        self.torch_dtype = getattr(config, "torch_dtype", None) or torch.bfloat16
         self.hf_quant_config = getattr(config, "quantization_config", None)
-        self.global_quant_config = None
-        self.layer_quant_config = {}
-        self.exclude_layers = []

         if self.hf_quant_config is None:
-            self.global_quant_config = LayerQuantConfig(
+            self.global_spec = LayerQuantConfig(
                 quant_type=QuantType.No, quant_dtype=self.torch_dtype
             )
+            self.layer_pattern_specs = []
+            self.exclude_layers = []
             self.quant_method = ""
             return

         self.quant_method = self.hf_quant_config.get("quant_method", "")
-        if self.quant_method == "quark":
-            layer_quant_config_dict = cast(
-                dict[str, Any], self.hf_quant_config.get("layer_quant_config", {})
-            )
-            for layer_name, layer_cfg in layer_quant_config_dict.items():
-                self.layer_quant_config[layer_name] = self.parse_quark_config_dict(
-                    layer_cfg
-                )

-            global_quant_config_dict = cast(
-                dict[str, Any], self.hf_quant_config.get("global_quant_config", {})
-            )
-            self.global_quant_config = self.parse_quark_config_dict(
-                global_quant_config_dict
-            )
+        # Use the parser registry to build a structured ParsedQuantConfig
+        parser = get_quant_parser(self.quant_method)
+        parsed_quant_config = parser.parse(self.hf_quant_config)
+        self.global_spec = parsed_quant_config.global_spec
+        self.layer_pattern_specs = parsed_quant_config.layer_pattern_specs
+        self.exclude_layers = list(parsed_quant_config.exclude_layers)

-            self.exclude_layers = cast(
-                list[str], self.hf_quant_config.get("exclude", [])
-            )
-        else:
-            self.parse_other_config()
+    # -- typed API (preferred) ----------------------------------------------
+
+    @property
+    def global_quant_config(self) -> LayerQuantConfig:
+        """Alias for ``global_spec``."""
+        return self.global_spec
+
+    def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig:
+        """Return the :class:`LayerQuantConfig` for *layer_name*.
+
+        Resolution order:
+        1. Check exclude list -> ``LayerQuantConfig.no_quant()``.
+        2. fnmatch-style pattern match in ``layer_pattern_specs``.
+        3. Fall back to ``global_spec``.
+        """
+        # 1. Exclude list
+        if self._is_excluded(layer_name):
+            return LayerQuantConfig(quant_dtype=self.torch_dtype)
+
+        # 2. Pattern match
+        for pattern, spec in self.layer_pattern_specs:
+            if "*" not in pattern:
+                if layer_name in pattern:
+                    return spec
+            elif fnmatch.fnmatch(layer_name, pattern):
+                return spec
+
+        # 3. Global default
+        return self.global_spec
+
+    # -- convenience properties (delegate to global_spec) ---------------------
+
+    @property
+    def quant_type(self) -> QuantType:
+        return self.global_spec.quant_type
+
+    @property
+    def quant_dtype(self) -> torch.dtype:
+        return self.global_spec.quant_dtype
+
+    @property
+    def is_dynamic(self) -> bool:
+        return self.global_spec.is_dynamic
+
+    # -- other methods ------------------------------------------------------

     def compute_hash(self) -> str:
         """
@@ -329,191 +355,53 @@ def compute_hash(self) -> str:
         the final hidden states.
         """
         factors: list[Any] = []
-        factors.append(self.global_quant_config)
-        factors.append(self.layer_quant_config)
+        factors.append(self.global_spec)
+        factors.append(self.layer_pattern_specs)
         factors.append(self.exclude_layers)
         hash_value = hashlib.sha256(str(factors).encode()).hexdigest()
         return hash_value

     def get_name(self):
-        """
-        Returns the quantization method name.
-        """
+        """Returns the quantization method name."""
         return self.quant_method

-    def parse_quark_config_dict(self, config: dict) -> LayerQuantConfig:
-        quant_type = None
-        quant_dtype = None
-        is_dynamic = True
-        weight_config = cast(dict[str, Any], config.get("weight", {}))
-        input_config = cast(dict[str, Any], config.get("input_tensors", {}))
-        weight_qscheme = cast(str, weight_config.get("qscheme", ""))
-        weight_dtype = weight_config.get("dtype", "")
-
-        # quant_type
-        if weight_qscheme == "per_channel":
-            quant_type = QuantType.per_Token
-        elif weight_qscheme == "per_tensor":
-            quant_type = QuantType.per_Tensor
-        elif weight_qscheme == "per_group":
-            # Currently, quark only supports group_size=32
-            quant_type = QuantType.per_1x32
-        else:
-            quant_type = QuantType.No
-
-        # quant_dtype
-        dtype = weight_dtype.split("_")[0]
-        if dtype.endswith("4"):
-            dtype += "x2"
-        quant_dtype = d_dtypes[dtype]
-
-        # is_dynamic
-        if input_config:
-            # input_dtype = input_config.get("dtype")
-            # input_qscheme = cast(str, input_config.get("qscheme"))
-            is_dynamic = cast(bool, input_config.get("is_dynamic", True))
-        return LayerQuantConfig(
-            quant_type=quant_type,
-            quant_dtype=quant_dtype,
-            is_dynamic=is_dynamic,
-            quant_method="quark",
-        )
-
-    # TODO: For now, it's just a temporary migration.
-    # We should subsequently refine them in a targeted manner.
-    def parse_other_config(self):
-        RE_QUANT_BLOCKSIZE = (
-            r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+),"
-        )
-        orig_quant_config = self.hf_quant_config
-        quant_method = self.quant_method
-        orig_quant_config_str = str(orig_quant_config)
-        if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str:
-            quant_type = QuantType.per_Token
-        elif group_size := re.search(RE_QUANT_BLOCKSIZE, orig_quant_config_str):
-            group_size = int(group_size.group(1))
-            assert group_size in (32, 128), f"Unsupported group size {group_size}"
-            if group_size == 128:
-                quant_type = QuantType.per_1x128
-            elif group_size == 32:
-                quant_type = QuantType.per_1x32
-        else:
-            quant_type = QuantType.per_Tensor
-
-        RE_QUANT_DTYPE = r"\'(?:d?type|weight_dtype|quant_method)\'\:\s*\'(\w+)\'"
-        quant_dtype = None
-        m = re.search(RE_QUANT_DTYPE, orig_quant_config_str)
-        if m and m.group(1).lower() in [
-            "fp8",
-            "fp4",
-            "int8",
-            "int4",
-            "fp8_e4m3",
-            "mxfp4",
-        ]:
-            dtype = m.group(1).lower().split("_")[0]
-            if dtype == "mxfp4":
-                dtype = "fp4"
-            if dtype.endswith("4"):
-                dtype += "x2"
-            quant_dtype = d_dtypes[dtype]
-        else:
-            bit_match = re.search(r"\'(?:num_)?bits\'\:\s*(\d+)", orig_quant_config_str)
-            if bit_match:
-                bit = int(bit_match.group(1))
-                dtype_match = re.search(RE_QUANT_DTYPE, orig_quant_config_str)
-                if dtype_match:
-                    dtype = dtype_match.group(1).lower()
-                    dtype_prefix = "i" if dtype.startswith("int") else "fp"
-                else:
-                    dtype_prefix = "i"
-                quant_dtype_str = (
-                    f"{dtype_prefix}{bit}" if bit != 4 else f"{dtype_prefix}{bit}x2"
-                )
-                quant_dtype = d_dtypes.get(quant_dtype_str, None)
-        assert (
-            quant_dtype is not None
-        ), f"Cannot parse quant dtype from {orig_quant_config_str}"
-        if quant_dtype == d_dtypes["fp4x2"]:
-            quant_type = QuantType.per_1x32
-
-        RE_STATIC_QUANT = r"\'(?:activation_scheme)\'\:\s*\'(static)\'"
-        if re.search(RE_STATIC_QUANT, orig_quant_config_str):
-            is_dynamic = False
-        else:
-            is_dynamic = True
-        if quant_method == "compressed-tensors":
-            exclude_layers_key = "ignore"
-        else:
-            logger.warning(
-                f"Using 'ignore' as key for exclude layers with quant_method "
-                f"{quant_method}, please double check the quantization config."
-            )
-            exclude_layers_key = "ignore"
-        exclude_layers = orig_quant_config.get(exclude_layers_key, [])
-
-        self.global_quant_config = LayerQuantConfig(
-            quant_type=quant_type,
-            quant_dtype=quant_dtype,
-            is_dynamic=is_dynamic,
-            quant_method=quant_method,
-        )
-        self.exclude_layers = exclude_layers
+    # -- internal helpers ---------------------------------------------------

-    def should_ignore_layer_quant(self, layer_name: str) -> bool:
-        # TODO: solve fused_mapping case
+    def _is_excluded(self, layer_name: str) -> bool:
         if layer_name is None or not self.exclude_layers:
             return False
         return any(
-            self.is_equal_or_regex_match(layer_name, ignore_str)
+            self._matches_exclude(layer_name, ignore_str)
             for ignore_str in self.exclude_layers
         )

-    def is_equal_or_regex_match(
-        self, layer_name: str, ignore_str: str, check_contains: bool = False
+    @staticmethod
+    def _matches_exclude(
+        layer_name: str, ignore_str: str, check_contains: bool = False
     ) -> bool:
-        """Match the target string or regular expression"""
+        """Match the target string or regular expression.
+
+        Supports exact match, prefix match (layer under an excluded module),
+        fnmatch glob patterns (``*`` / ``?``), and ``re:`` regex patterns.
+        """
         if ignore_str.startswith("re:"):
-            # case "re:model.layers.*self_attn.*", remove the 're:' prefix
             pattern = ignore_str[3:]
             if re.search(pattern, layer_name):
                 return True
-        # case exclude_layer like "model.layers.0.self_attn.q_a_proj" (dpsk-attn)
-        # a common prefix for linear layers in attn like "model.layers.0.self_attn"
+        elif "*" in ignore_str or "?" in ignore_str:
+            # Glob pattern: match exact or as prefix of deeper sub-modules
+            if fnmatch.fnmatch(layer_name, ignore_str):
+                return True
+            if fnmatch.fnmatch(layer_name, ignore_str + ".*"):
+                return True
         elif check_contains:
             return layer_name.lower() in ignore_str.lower()
-        elif ignore_str == layer_name:
-            return True
+        else:
+            # Exact match or prefix match (e.g. "lm_head" excludes "lm_head.weight")
+            if layer_name == ignore_str or layer_name.startswith(ignore_str + "."):
+                return True
         return False

-    def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig:
-        if self.should_ignore_layer_quant(layer_name=layer_name):
-            # return unquantized config
-            return LayerQuantConfig(quant_dtype=self.torch_dtype)
-        # layer quant config
-        layer_quant_config = None
-        if self.layer_quant_config:
-
-            def _matches_pattern(layer_name, pattern):
-                if "*" not in pattern:
-                    return layer_name in pattern
-                return fnmatch.fnmatch(layer_name, pattern)
-
-            for name_pattern, config in self.layer_quant_config.items():
-                if _matches_pattern(layer_name, name_pattern):
-                    layer_quant_config = config
-
-        layer_quant_config = (
-            self.global_quant_config
-            if layer_quant_config is None
-            else layer_quant_config
-        )
-        # TODO: if use_aiter, we can customize the quantization format here, such as dpsk
-        # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs,
-        # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs
-
-        return layer_quant_config
-
     def remap_layer_name(
         self, hf_config: PretrainedConfig, packed_modules_mapping: dict | None = None
     ):
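The new exclude matching above accepts four pattern forms. A few illustrative checks, not part of the diff (`_matches_exclude` is a private staticmethod and is only called directly here for demonstration; the layer names are hypothetical):

```python
m = QuantizationConfig._matches_exclude

assert m("lm_head", "lm_head")                                    # exact match
assert m("lm_head.weight", "lm_head")                             # prefix: child of an excluded module
assert m("model.layers.0.self_attn.q_proj",
         "model.layers.*.self_attn")                               # glob, extended with ".*" for children
assert m("model.layers.0.self_attn.o_proj", "re:.*self_attn.*")   # "re:" regex pattern
assert not m("model.layers.0.mlp.gate_proj", "lm_head")
```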
@@ -556,11 +444,11 @@ def _remap_layer_name(name: str) -> list[str]:
                 return [name.replace(packed_key, packed_remap_part, 1)]
             return [name]

-        new_layer_quant_config = {}
-        for layer_name, layer_qconfig in self.layer_quant_config.items():
-            for remapped in _remap_layer_name(layer_name):
-                new_layer_quant_config[remapped] = layer_qconfig
-        self.layer_quant_config = new_layer_quant_config
+        new_pattern_specs = []
+        for pattern, spec in self.layer_pattern_specs:
+            for remapped in _remap_layer_name(pattern):
+                new_pattern_specs.append((remapped, spec))
+        self.layer_pattern_specs = new_pattern_specs

         new_exclude = []
         for name in self.exclude_layers: