diff --git a/inference_lib/src/fp_quant/utils/replace.py b/inference_lib/src/fp_quant/utils/replace.py index 842733c..3aef831 100644 --- a/inference_lib/src/fp_quant/utils/replace.py +++ b/inference_lib/src/fp_quant/utils/replace.py @@ -1,4 +1,4 @@ -import torch +import re from torch import nn from .config import FPQuantConfig @@ -40,7 +40,7 @@ def replace_with_fp_quant_linear( # Check if the current key is not in the `quantization_config.modules_to_not_convert` current_key_name_str = ".".join(current_key_name) if not any( - current_key_name_str.endswith(key) + re.search(key, current_key_name_str) is not None for key in fp_quant_linear_config.modules_to_not_convert ): with init_empty_weights():