diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py
index db95d9547..1dbaf4038 100644
--- a/auto_round/compressors/utils.py
+++ b/auto_round/compressors/utils.py
@@ -864,14 +864,14 @@ def get_fp_layer_names(model: torch.nn.Module, ignore_layers: str):
         list: A list of layer names that match the specified FP layers or are
             subcomponents of those layers.
     """
-    from auto_round.utils import SUPPORTED_LAYER_TYPES
+    from auto_round.utils import INNER_SUPPORTED_LAYER_TYPES, SUPPORTED_LAYER_TYPES
 
     if not ignore_layers:
         return []
     ignore_layers = ignore_layers.replace(" ", "").split(",")
     all_layer_names = []
     for n, m in model.named_modules():
-        if type(m) in SUPPORTED_LAYER_TYPES:
+        if type(m) in SUPPORTED_LAYER_TYPES or m.__class__.__name__ in INNER_SUPPORTED_LAYER_TYPES:
             all_layer_names.append(n)
 
     not_to_quantized_layers = []
diff --git a/test/test_cpu/utils/test_compressor_utils.py b/test/test_cpu/utils/test_compressor_utils.py
new file mode 100644
index 000000000..1aa77f67c
--- /dev/null
+++ b/test/test_cpu/utils/test_compressor_utils.py
@@ -0,0 +1,124 @@
+"""
+Unit tests for auto_round.compressors.utils module.
+"""
+
+import pytest
+import torch
+
+from auto_round.compressors.utils import get_fp_layer_names
+
+
+class TestGetFpLayerNames:
+    """Test suite for get_fp_layer_names function."""
+
+    def test_regular_linear_layers(self):
+        """Test with regular Linear layers."""
+
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer1 = torch.nn.Linear(10, 10)
+                self.layer2 = torch.nn.Linear(10, 10)
+                self.mlp = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
+
+        model = MockModel()
+
+        # Test finding specific layer
+        result = get_fp_layer_names(model, "layer1")
+        assert "layer1" in result, "Should find layer1"
+
+        # Test finding layers with pattern
+        result = get_fp_layer_names(model, "mlp")
+        assert len(result) == 2, "Should find 2 layers in mlp"
+        assert "mlp.0" in result and "mlp.1" in result
+
+    def test_fp8linear_layers(self):
+        """Test with FP8Linear layers (mocked by creating a proper class)."""
+
+        # Create a proper mock FP8Linear class
+        class FP8Linear(torch.nn.Linear):
+            """Mock FP8Linear class for testing."""
+
+            def __init__(self, in_features, out_features):
+                super().__init__(in_features, out_features)
+
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer1 = torch.nn.Linear(10, 10)
+                # Use proper FP8Linear mock
+                self.layer2 = FP8Linear(10, 10)
+
+                self.mlp = torch.nn.Sequential()
+                linear1 = torch.nn.Linear(10, 10)
+                self.mlp.add_module("0", linear1)
+                linear2 = FP8Linear(10, 10)
+                self.mlp.add_module("1", linear2)
+
+        model = MockModel()
+
+        # Test finding FP8Linear layer
+        result = get_fp_layer_names(model, "layer2")
+        assert "layer2" in result, "Should find FP8Linear layer (layer2)"
+
+        # Test finding mixed Linear and FP8Linear in mlp
+        result = get_fp_layer_names(model, "mlp")
+        assert len(result) == 2, "Should find 2 layers in mlp (both Linear and FP8Linear)"
+        assert "mlp.0" in result, "Should find regular Linear in mlp"
+        assert "mlp.1" in result, "Should find FP8Linear in mlp"
+
+    def test_empty_ignore_layers(self):
+        """Test with empty ignore_layers string."""
+
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer1 = torch.nn.Linear(10, 10)
+
+        model = MockModel()
+        result = get_fp_layer_names(model, "")
+        assert len(result) == 0, "Empty ignore_layers should return empty list"
+
+    def test_none_ignore_layers(self):
+        """Test with None as ignore_layers input."""
+
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer1 = torch.nn.Linear(10, 10)
+
+        model = MockModel()
+        result = get_fp_layer_names(model, None)
+        assert len(result) == 0, "None ignore_layers should return empty list"
+
+    def test_multiple_ignore_patterns(self):
+        """Test with multiple ignore patterns."""
+
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer1 = torch.nn.Linear(10, 10)
+                self.layer2 = torch.nn.Linear(10, 10)
+                self.layer3 = torch.nn.Linear(10, 10)
+
+        model = MockModel()
+        result = get_fp_layer_names(model, "layer1,layer3")
+        assert "layer1" in result, "Should find layer1"
+        assert "layer3" in result, "Should find layer3"
+        assert "layer2" not in result, "Should not find layer2"
+
+    def test_pattern_with_digits(self):
+        """Test pattern matching with digits (special case in the code)."""
+
+        class MockModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layers = torch.nn.ModuleList(
+                    [torch.nn.Linear(10, 10), torch.nn.Linear(10, 10), torch.nn.Linear(10, 10)]
+                )
+
+        model = MockModel()
+        # Pattern ending with digit should get a dot appended for matching
+        result = get_fp_layer_names(model, "layers.0")
+        # Should match 'layers.0'
+        assert "layers.0" in result, "Should match layers.0"
diff --git a/test/test_cuda/advanced/test_fp8_input.py b/test/test_cuda/advanced/test_fp8_input.py
index ec3dc6bf3..a980d40ec 100644
--- a/test/test_cuda/advanced/test_fp8_input.py
+++ b/test/test_cuda/advanced/test_fp8_input.py
@@ -141,3 +141,25 @@ def test_diff_datatype(self):
         ar = AutoRound(model_name, iters=iters, scheme=scheme)
         ar.quantize_and_save(output_dir=self.save_dir)
         shutil.rmtree(self.save_dir, ignore_errors=True)
+
+    def test_ignore_layers_fp8(self):
+        """Test that ignore_layers works correctly with FP8 models."""
+        from auto_round.compressors.utils import get_fp_layer_names
+
+        model, tokenizer = self.tiny_fp8_model()
+
+        # Test that get_fp_layer_names can find FP8Linear layers
+        # Using "mlp" as the ignore pattern which should match mlp layers in the model
+        layer_names = get_fp_layer_names(model, "mlp")
+
+        # Verify that some layers were found
+        assert len(layer_names) > 0, "Should find layers matching 'mlp' pattern in FP8 model"
+
+        # Now test with AutoRound using ignore_layers
+        ar = AutoRound(model=model, tokenizer=tokenizer, iters=0, ignore_layers="mlp")
+        ar.quantize_and_save(output_dir=self.save_dir)
+
+        # Verify the model was saved successfully
+        assert os.path.exists(self.save_dir), "Model should be saved"
+
+        shutil.rmtree(self.save_dir, ignore_errors=True)