
Commit 0ce71af

Refactor: typed LayerQuantConfig (frozen dataclass) + parser registry for per-layer quant config
Introduce atom/quant_spec.py with:
- LayerQuantConfig: frozen dataclass with typed attribute access (quant_type, quant_dtype, is_dynamic, quant_method), replacing the old dict-based LayerQuantConfig(dict) subclass
- ParsedQuantConfig: structured output of HF config parsing
- Parser registry (@register_quant_parser) with QuarkParser and GenericParser (fallback for compressed-tensors, GPTQ, AWQ, etc.)

Refactor QuantizationConfig (atom/config.py):
- Internal storage now uses ParsedQuantConfig built via the parser registry
- get_layer_quant_config(prefix) -> LayerQuantConfig (frozen dataclass)
- global_quant_config property -> LayerQuantConfig
- Convenience properties: quant_type, quant_dtype, is_dynamic
- compute_hash() uses the typed internal structures

Migrate all consumers to typed attribute access:
- linear.py: layer_quant_config.quant_type instead of ["quant_type"]
- moe.py: all MoE method classes use LayerQuantConfig type hints
- activation.py, layernorm.py: accept a prefix param and use get_layer_quant_config() instead of bypassing it with global_quant_config
- deepseek_mtp.py, deepseek_v2.py, llama.py: use get_layer_quant_config()

Fix GenericParser exclude-layer key handling (atom/quant_spec.py):
- Different quantizers use different keys for excluded layers: compressed-tensors uses "ignore", gpt-oss/HF uses "modules_to_not_convert", Quark uses "exclude"
- GenericParser now tries all three keys in priority order so excluded layers are never silently treated as quantized

Fix hard-coded quant_config=None across models:
- gpt_oss.py OAIAttention: qkv_proj and o_proj were passing quant_config=None, preventing fp8/mxfp4 quantization on attention projections in Quark gpt-oss models (e.g. fp8 qkv + mxfp4 MoE); both now receive quant_config
- deepseek_v2.py Indexer: weights_proj passed quant_config=None while the sibling linears wq_b and wk correctly used quant_config; fixed for consistency
- qwen3_next.py GatedDeltaNet: the conv1d ColumnParallelLinear omitted quant_config while the other linears in the same class passed it; fixed
1 parent cbca4c1 commit 0ce71af

14 files changed

Lines changed: 751 additions & 422 deletions
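
Note: of the 14 changed files, only the atom/config.py diff is reproduced below; the new atom/quant_spec.py module is described only by the commit message above. The following is a minimal sketch of the pieces that message names. The class and decorator names come from the message; the signatures, field defaults, and the exclude-key lookup are illustrative assumptions, not the actual implementation.

# Sketch of the atom/quant_spec.py types described in the commit message
# (illustrative only; the real module may differ in details).
from dataclasses import dataclass, field
from typing import Any, Callable

import torch
from aiter import QuantType


@dataclass(frozen=True)
class LayerQuantConfig:
    """Immutable, typed per-layer quant spec (replaces the old dict subclass)."""

    quant_type: QuantType = QuantType.No
    quant_dtype: torch.dtype = torch.bfloat16
    is_dynamic: bool = True
    quant_method: str = ""


@dataclass
class ParsedQuantConfig:
    """Structured result of parsing an HF quantization_config dict."""

    global_spec: LayerQuantConfig = field(default_factory=LayerQuantConfig)
    layer_specs: dict[str, LayerQuantConfig] = field(default_factory=dict)
    layer_pattern_specs: list[tuple[str, LayerQuantConfig]] = field(default_factory=list)
    exclude_layers: list[str] = field(default_factory=list)


_PARSERS: dict[str, type] = {}


def register_quant_parser(*quant_methods: str) -> Callable[[type], type]:
    """Class decorator registering a parser under one or more quant_method names."""

    def wrap(cls: type) -> type:
        for name in quant_methods:
            _PARSERS[name] = cls
        return cls

    return wrap


def get_quant_parser(quant_method: str):
    """Look up a registered parser; unknown methods fall back to GenericParser."""
    return _PARSERS.get(quant_method, GenericParser)()


@register_quant_parser("quark")
class QuarkParser:
    def parse(self, hf_quant_config: dict[str, Any]) -> ParsedQuantConfig:
        ...  # per-layer and global Quark specs (omitted in this sketch)


class GenericParser:
    """Fallback for compressed-tensors, GPTQ, AWQ, etc."""

    # Quantizers name their excluded-layer list differently; try all known keys
    # in priority order so excluded layers are never silently quantized.
    _EXCLUDE_KEYS = ("ignore", "modules_to_not_convert", "exclude")

    def parse(self, hf_quant_config: dict[str, Any]) -> ParsedQuantConfig:
        exclude: list[str] = []
        for key in self._EXCLUDE_KEYS:
            if hf_quant_config.get(key):
                exclude = list(hf_quant_config[key])
                break
        return ParsedQuantConfig(exclude_layers=exclude)  # global_spec parsing omitted

A frozen dataclass can be shared across layers without defensive copying, and the registry keeps quantizer-specific parsing out of atom/config.py itself.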

atom/config.py

Lines changed: 102 additions & 212 deletions
@@ -8,11 +8,15 @@
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, cast, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from aiter import QuantType
-from aiter.utility.dtypes import d_dtypes
+from atom.quant_spec import (
+    LayerQuantConfig,
+    ParsedQuantConfig,
+    get_quant_parser,
+)
 from atom.utils import envs, get_open_port
 from atom.utils.distributed.utils import stateless_init_torch_distributed_process_group
 from torch.distributed import ProcessGroup, ReduceOp
@@ -251,70 +255,94 @@ def set_splitting_ops_for_v1(self):
         ]
 
 
-class LayerQuantConfig(dict):
-    def __init__(
-        self,
-        quant_type=QuantType.No,
-        quant_dtype=torch.bfloat16,
-        is_dynamic=True,
-        quant_method="",
-    ):
-        """
-        Core components of layer_quant
-        """
-        super().__init__()
-        self["quant_type"] = quant_type if quant_type is not None else QuantType.No
-        self["quant_dtype"] = quant_dtype if quant_dtype is not None else torch.bfloat16
-        self["is_dynamic"] = is_dynamic
-        self["quant_method"] = quant_method
+class QuantizationConfig:
+    """Model-wide quantization configuration.
 
+    API:
+    - ``get_layer_quant_config(prefix)`` -> :class:`LayerQuantConfig`
+    - ``global_quant_config`` property -> :class:`LayerQuantConfig`
+    - ``quant_type``, ``quant_dtype``, ``is_dynamic`` convenience properties
+    """
 
-class QuantizationConfig:
     def __init__(self, config: PretrainedConfig = None):
         if config is None:
             self.torch_dtype = torch.bfloat16
             self.hf_quant_config = None
-            self.global_quant_config = LayerQuantConfig()
-            self.layer_quant_config = {}
-            self.exclude_layers = []
+            self._parsed = ParsedQuantConfig()
+            self.exclude_layers: list[str] = []
             self.quant_method = ""
             return
 
         self.torch_dtype = getattr(config, "torch_dtype", torch.bfloat16)
         self.hf_quant_config = getattr(config, "quantization_config", None)
-        self.global_quant_config = None
-        self.layer_quant_config = {}
         self.exclude_layers = []
 
         if self.hf_quant_config is None:
-            self.global_quant_config = LayerQuantConfig(
-                quant_type=QuantType.No, quant_dtype=self.torch_dtype
+            self._parsed = ParsedQuantConfig(
+                global_spec=LayerQuantConfig(
+                    quant_type=QuantType.No, quant_dtype=self.torch_dtype
+                )
             )
             self.quant_method = ""
             return
 
         self.quant_method = self.hf_quant_config.get("quant_method", "")
-        if self.quant_method == "quark":
-            layer_quant_config_dict = cast(
-                dict[str, Any], self.hf_quant_config.get("layer_quant_config", {})
-            )
-            for layer_name, layer_cfg in layer_quant_config_dict.items():
-                self.layer_quant_config[layer_name] = self.parse_quark_config_dict(
-                    layer_cfg
-                )
 
-            global_quant_config_dict = cast(
-                dict[str, Any], self.hf_quant_config.get("global_quant_config", {})
-            )
-            self.global_quant_config = self.parse_quark_config_dict(
-                global_quant_config_dict
-            )
+        # Use the parser registry to build a structured ParsedQuantConfig
+        parser = get_quant_parser(self.quant_method)
+        self._parsed = parser.parse(self.hf_quant_config)
+        self.exclude_layers = list(self._parsed.exclude_layers)
 
-            self.exclude_layers = cast(
-                list[str], self.hf_quant_config.get("exclude", [])
-            )
-        else:
-            self.parse_other_config()
+    # -- typed API (preferred) ----------------------------------------------
+
+    @property
+    def global_quant_config(self) -> LayerQuantConfig:
+        """The default quantization spec for all layers."""
+        return self._parsed.global_spec
+
+    def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig:
+        """Return the :class:`LayerQuantConfig` for *layer_name*.
+
+        Resolution order:
+        1. Check exclude list -> ``LayerQuantConfig.no_quant()``.
+        2. Exact match in ``parsed.layer_specs``.
+        3. fnmatch-style pattern match in ``parsed.layer_pattern_specs``.
+        4. Fall back to ``global_quant_config``.
+        """
+        # 1. Exclude list
+        if self._is_excluded(layer_name):
+            return LayerQuantConfig(quant_dtype=self.torch_dtype)
+
+        # 2. Exact match
+        if layer_name in self._parsed.layer_specs:
+            return self._parsed.layer_specs[layer_name]
+
+        # 3. Pattern match
+        for pattern, spec in self._parsed.layer_pattern_specs:
+            if "*" not in pattern:
+                if layer_name in pattern:
+                    return spec
+            elif fnmatch.fnmatch(layer_name, pattern):
+                return spec
+
+        # 4. Global default
+        return self._parsed.global_spec
+
+    # -- convenience properties (delegate to global_quant_config) -------------
+
+    @property
+    def quant_type(self) -> QuantType:
+        return self._parsed.global_spec.quant_type
+
+    @property
+    def quant_dtype(self) -> torch.dtype:
+        return self._parsed.global_spec.quant_dtype
+
+    @property
+    def is_dynamic(self) -> bool:
+        return self._parsed.global_spec.is_dynamic
+
+    # -- other methods ------------------------------------------------------
 
     def compute_hash(self) -> str:
         """
@@ -329,191 +357,53 @@ def compute_hash(self) -> str:
         the final hidden states.
         """
         factors: list[Any] = []
-        factors.append(self.global_quant_config)
-        factors.append(self.layer_quant_config)
+        factors.append(self._parsed.global_spec)
+        factors.append(self._parsed.layer_pattern_specs)
         factors.append(self.exclude_layers)
         hash_value = hashlib.sha256(str(factors).encode()).hexdigest()
         return hash_value
 
     def get_name(self):
-        """
-        Returns the quantization method name.
-        """
+        """Returns the quantization method name."""
        return self.quant_method
 
-    def parse_quark_config_dict(self, config: dict) -> LayerQuantConfig:
-        quant_type = None
-        quant_dtype = None
-        is_dynamic = True
-        weight_config = cast(dict[str, Any], config.get("weight", {}))
-        input_config = cast(dict[str, Any], config.get("input_tensors", {}))
-        weight_qscheme = cast(str, weight_config.get("qscheme", ""))
-        weight_dtype = weight_config.get("dtype", "")
-
-        # quant_type
-        if weight_qscheme == "per_channel":
-            quant_type = QuantType.per_Token
-        elif weight_qscheme == "per_tensor":
-            quant_type = QuantType.per_Tensor
-        elif weight_qscheme == "per_group":
-            # Currently, quark only supports group_size=32
-            quant_type = QuantType.per_1x32
-        else:
-            quant_type = QuantType.No
-
-        # quant_dtype
-        dtype = weight_dtype.split("_")[0]
-        if dtype.endswith("4"):
-            dtype += "x2"
-        quant_dtype = d_dtypes[dtype]
-
-        # is_dynamic
-        if input_config:
-            # input_dtype = input_config.get("dtype")
-            # input_qscheme = cast(str, input_config.get("qscheme"))
-            is_dynamic = cast(bool, input_config.get("is_dynamic", True))
-        return LayerQuantConfig(
-            quant_type=quant_type,
-            quant_dtype=quant_dtype,
-            is_dynamic=is_dynamic,
-            quant_method="quark",
-        )
-
-    # TODO: For now, it's just a temporary migration.
-    # We should subsequently refine them in a targeted manner.
-    def parse_other_config(self):
-        RE_QUANT_BLOCKSIZE = (
-            r"\'(?:group_size|weight_block_size)\'\:\s*(?:\[\n*)\s*(\d+),"
-        )
-        orig_quant_config = self.hf_quant_config
-        quant_method = self.quant_method
-        orig_quant_config_str = str(orig_quant_config)
-        if quant_method == "compressed-tensors" or "channel'," in orig_quant_config_str:
-            quant_type = QuantType.per_Token
-        elif group_size := re.search(RE_QUANT_BLOCKSIZE, orig_quant_config_str):
-            group_size = int(group_size.group(1))
-            assert group_size in (32, 128), f"Unsupported group size {group_size}"
-            if group_size == 128:
-                quant_type = QuantType.per_1x128
-            elif group_size == 32:
-                quant_type = QuantType.per_1x32
-        else:
-            quant_type = QuantType.per_Tensor
-
-        RE_QUANT_DTYPE = r"\'(?:d?type|weight_dtype|quant_method)\'\:\s*\'(\w+)\'"
-        quant_dtype = None
-        m = re.search(RE_QUANT_DTYPE, orig_quant_config_str)
-        if m and m.group(1).lower() in [
-            "fp8",
-            "fp4",
-            "int8",
-            "int4",
-            "fp8_e4m3",
-            "mxfp4",
-        ]:
-            dtype = m.group(1).lower().split("_")[0]
-            if dtype == "mxfp4":
-                dtype = "fp4"
-            if dtype.endswith("4"):
-                dtype += "x2"
-            quant_dtype = d_dtypes[dtype]
-        else:
-            bit_match = re.search(r"\'(?:num_)?bits\'\:\s*(\d+)", orig_quant_config_str)
-            if bit_match:
-                bit = int(bit_match.group(1))
-                dtype_match = re.search(RE_QUANT_DTYPE, orig_quant_config_str)
-                if dtype_match:
-                    dtype = dtype_match.group(1).lower()
-                    dtype_prefix = "i" if dtype.startswith("int") else "fp"
-                else:
-                    dtype_prefix = "i"
-                quant_dtype_str = (
-                    f"{dtype_prefix}{bit}" if bit != 4 else f"{dtype_prefix}{bit}x2"
-                )
-                quant_dtype = d_dtypes.get(quant_dtype_str, None)
-        assert (
-            quant_dtype is not None
-        ), f"Cannot parse quant dtype from {orig_quant_config_str}"
-        if quant_dtype == d_dtypes["fp4x2"]:
-            quant_type = QuantType.per_1x32
-
-        RE_STATIC_QUANT = r"\'(?:activation_scheme)\'\:\s*\'(static)\'"
-        if re.search(RE_STATIC_QUANT, orig_quant_config_str):
-            is_dynamic = False
-        else:
-            is_dynamic = True
-        if quant_method == "compressed-tensors":
-            exclude_layers_key = "ignore"
-        else:
-            logger.warning(
-                f"Using 'ignore' as key for exclude layers with quant_method "
-                f"{quant_method}, please double check the quantization config."
-            )
-            exclude_layers_key = "ignore"
-        exclude_layers = orig_quant_config.get(exclude_layers_key, [])
-
-        self.global_quant_config = LayerQuantConfig(
-            quant_type=quant_type,
-            quant_dtype=quant_dtype,
-            is_dynamic=is_dynamic,
-            quant_method=quant_method,
-        )
-        self.exclude_layers = exclude_layers
+    # -- internal helpers ---------------------------------------------------
 
-    def should_ignore_layer_quant(self, layer_name: str) -> bool:
-        # TODO: solve fused_mapping case
+    def _is_excluded(self, layer_name: str) -> bool:
         if layer_name is None or not self.exclude_layers:
             return False
         return any(
-            self.is_equal_or_regex_match(layer_name, ignore_str)
+            self._matches_exclude(layer_name, ignore_str)
             for ignore_str in self.exclude_layers
         )
 
-    def is_equal_or_regex_match(
-        self, layer_name: str, ignore_str: str, check_contains: bool = False
+    @staticmethod
+    def _matches_exclude(
+        layer_name: str, ignore_str: str, check_contains: bool = False
     ) -> bool:
-        """Match the target string or regular expression"""
+        """Match the target string or regular expression.
+
+        Supports exact match, prefix match (layer under an excluded module),
+        fnmatch glob patterns (``*`` / ``?``), and ``re:`` regex patterns.
+        """
         if ignore_str.startswith("re:"):
-            # case "re:model.layers.*self_attn.*", remove the 're:' prefix
            pattern = ignore_str[3:]
            if re.search(pattern, layer_name):
                return True
-        # case exclude_layer like "model.layers.0.self_attn.q_a_proj" (dpsk-attn)
-        # a common prefix for linear layers in attn like "model.layers.0.self_attn"
+        elif "*" in ignore_str or "?" in ignore_str:
+            # Glob pattern: match exact or as prefix of deeper sub-modules
+            if fnmatch.fnmatch(layer_name, ignore_str):
+                return True
+            if fnmatch.fnmatch(layer_name, ignore_str + ".*"):
+                return True
         elif check_contains:
            return layer_name.lower() in ignore_str.lower()
-        elif ignore_str == layer_name:
-            return True
+        else:
+            # Exact match or prefix match (e.g. "lm_head" excludes "lm_head.weight")
+            if layer_name == ignore_str or layer_name.startswith(ignore_str + "."):
+                return True
         return False
 
-    def get_layer_quant_config(self, layer_name: str) -> LayerQuantConfig:
-        if self.should_ignore_layer_quant(layer_name=layer_name):
-            # return unquantized config
-            return LayerQuantConfig(quant_dtype=self.torch_dtype)
-        # layer quant config
-        layer_quant_config = None
-        if self.layer_quant_config:
-
-            def _matches_pattern(layer_name, pattern):
-                if "*" not in pattern:
-                    return layer_name in pattern
-                return fnmatch.fnmatch(layer_name, pattern)
-
-            for name_pattern, config in self.layer_quant_config.items():
-                if _matches_pattern(layer_name, name_pattern):
-                    layer_quant_config = config
-
-        layer_quant_config = (
-            self.global_quant_config
-            if layer_quant_config is None
-            else layer_quant_config
-        )
-        # TODO: if use_aiter, we can customize the quantization format here, such as dpsk
-        # For FP4 and use_triton_gemm(), fused_qkv_a_proj and q_b_proj are AITER-Triton FP4 GEMMs but o_proj remains AITER BF16 GEMMs,
-        # For FP8 and use_triton_gemm(), fused_qkv_a_proj is AITER-Triton FP8 GEMMs while others remain AITER FP8 GEMMs
-
-        return layer_quant_config
-
     def remap_layer_name(
         self, hf_config: PretrainedConfig, packed_modules_mapping: dict | None = None
     ):
@@ -556,11 +446,11 @@ def _remap_layer_name(name: str) -> list[str]:
                 return [name.replace(packed_key, packed_remap_part, 1)]
             return [name]
 
-        new_layer_quant_config = {}
-        for layer_name, layer_qconfig in self.layer_quant_config.items():
-            for remapped in _remap_layer_name(layer_name):
-                new_layer_quant_config[remapped] = layer_qconfig
-        self.layer_quant_config = new_layer_quant_config
+        new_pattern_specs = []
+        for pattern, spec in self._parsed.layer_pattern_specs:
+            for remapped in _remap_layer_name(pattern):
+                new_pattern_specs.append((remapped, spec))
+        self._parsed.layer_pattern_specs = new_pattern_specs
 
         new_exclude = []
         for name in self.exclude_layers:
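The consumer-side migration (linear.py, moe.py, and the model files) is not shown in this excerpt; roughly, call sites move from dict lookups to typed attribute access on the object returned by get_layer_quant_config(). An illustrative fragment follows — hf_config is a loaded transformers config, and the layer prefix and run_quantized_gemm are invented names, not code from the repo:

# Illustrative only: old dict-style access vs. the new typed access.
from aiter import QuantType
from atom.config import QuantizationConfig

quant_config = QuantizationConfig(hf_config)
layer_cfg = quant_config.get_layer_quant_config("model.layers.0.self_attn.qkv_proj")

# before: layer_quant_config["quant_type"], layer_quant_config["is_dynamic"]
# after:  typed attributes on a frozen dataclass
if layer_cfg.quant_type != QuantType.No:
    run_quantized_gemm(dtype=layer_cfg.quant_dtype, dynamic=layer_cfg.is_dynamic)

# An excluded layer (matched by exact name, prefix, glob, or "re:" pattern in the
# exclude list) resolves to an unquantized spec in the model's torch_dtype:
lm_head_cfg = quant_config.get_layer_quant_config("lm_head")
assert lm_head_cfg.quant_type == QuantType.No  # assuming "lm_head" is in the exclude list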