Skip to content
33 changes: 33 additions & 0 deletions examples/diffusers/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import torch.nn as nn
from calib.plugin_calib import PercentileCalibrator

Expand Down Expand Up @@ -98,6 +99,38 @@
"algorithm": {"method": "svdquant", "lowrank": 32},
}

# Layer-name patterns for the FLUX Dev mixed-format recipe, derived from the
# official NVFP4 metadata. The names follow the diffusers module naming used
# before export; they are not ComfyUI names.
FLUX_DEV_FP8_LAYER_PATTERNS = [
    # Attention projections of the double-stream transformer blocks.
    *(
        f"transformer_blocks.*.attn.{projection}"
        for projection in (
            "to_q",
            "to_k",
            "to_v",
            "add_q_proj",
            "add_k_proj",
            "add_v_proj",
        )
    ),
    # Linear layers inside the norm modules.
    "transformer_blocks.*.norm1.linear",
    "transformer_blocks.*.norm1_context.linear",
    "single_transformer_blocks.*.norm.linear",
]


def _build_flux_dev_nvfp4_mixed_config():
    """Build the FLUX Dev mixed-format quantization config.

    Starts from a deep copy of ``NVFP4_FP8_MHA_CONFIG`` (the base config is
    left untouched) and, for every entry in ``FLUX_DEV_FP8_LAYER_PATTERNS``,
    registers an FP8 override for both the weight quantizer and the input
    quantizer of the matching layers.

    Returns:
        dict: A new config dict. Every override entry in ``quant_cfg`` is its
        own dict object, so editing one entry in place does not affect the
        others.
    """
    cfg = copy.deepcopy(NVFP4_FP8_MHA_CONFIG)
    quant_cfg = cfg["quant_cfg"]
    fp8_override = {
        "num_bits": (4, 3),
        "axis": None,
        "block_sizes": None,
        "enable": True,
    }
    # NOTE(review): if these keys are matched as glob wildcards (as the "*"
    # suggests), the leading "*" makes the "transformer_blocks.*" patterns
    # also match names containing "single_transformer_blocks." -- confirm
    # that this is intended.
    for layer_pattern in FLUX_DEV_FP8_LAYER_PATTERNS:
        for quantizer in ("weight_quantizer", "input_quantizer"):
            # Store a fresh copy per key: sharing one dict across all entries
            # would let an in-place edit of one override leak into the rest.
            quant_cfg[f"*{layer_pattern}*{quantizer}*"] = dict(fp8_override)
    return cfg


# Result of _build_flux_dev_nvfp4_mixed_config(), evaluated once when this
# module is imported: a deep copy of NVFP4_FP8_MHA_CONFIG extended with the
# overrides for FLUX_DEV_FP8_LAYER_PATTERNS.
FLUX_DEV_NVFP4_MIXED_CONFIG = _build_flux_dev_nvfp4_mixed_config()


def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, **kwargs):
algo_cfg = {"method": quant_algo}
Expand Down
Loading