diff --git a/modelaudit/scanners/pickle_scanner.py b/modelaudit/scanners/pickle_scanner.py
index 108f37ff..7e21e46b 100644
--- a/modelaudit/scanners/pickle_scanner.py
+++ b/modelaudit/scanners/pickle_scanner.py
@@ -6882,14 +6882,31 @@ def extract_metadata(self, file_path: str) -> dict[str, Any]:
         metadata = super().extract_metadata(file_path)
         allow_deserialization = bool(self.config.get("allow_metadata_deserialization"))
+        max_metadata_read_size = int(self.config.get("max_metadata_pickle_read_size", 10 * 1024 * 1024))

         try:
             import pickle
             import pickletools
             from io import BytesIO

+            if max_metadata_read_size <= 0:
+                raise ValueError(
+                    f"Invalid pickle metadata read limit: {max_metadata_read_size} (must be greater than 0)"
+                )
+
+            file_size = self.get_file_size(file_path)
+            if file_size > max_metadata_read_size:
+                raise ValueError(
+                    f"Pickle metadata read limit exceeded: {file_size} bytes (max: {max_metadata_read_size})"
+                )
+
             with open(file_path, "rb") as f:
-                pickle_data = f.read()
+                pickle_data = f.read(max_metadata_read_size + 1)
+
+            if len(pickle_data) > max_metadata_read_size:
+                raise ValueError(
+                    f"Pickle metadata read limit exceeded: {len(pickle_data)} bytes (max: {max_metadata_read_size})"
+                )

             # Analyze pickle structure
             metadata.update(
diff --git a/tests/test_metadata_extractor.py b/tests/test_metadata_extractor.py
index 6ca38ddb..5a98843d 100644
--- a/tests/test_metadata_extractor.py
+++ b/tests/test_metadata_extractor.py
@@ -466,6 +466,27 @@ def __reduce__(self):
         assert "REDUCE" in metadata.get("dangerous_opcodes", [])
         assert metadata.get("has_dangerous_opcodes") is True

+    @pytest.mark.parametrize(
+        ("limit", "expected_error"),
+        [
+            (64, "read limit exceeded"),
+            (0, "must be greater than 0"),
+            (-1, "must be greater than 0"),
+        ],
+    )
+    def test_pickle_metadata_enforces_read_limit(self, tmp_path: Path, limit: int, expected_error: str) -> None:
+        """Ensure pickle metadata extraction rejects oversized and invalid read limits."""
+        from modelaudit.scanners.pickle_scanner import PickleScanner
+
+        pkl_file = tmp_path / "oversized.pkl"
+        pkl_file.write_bytes(b"x" * 128)
+
+        scanner = PickleScanner({"max_metadata_pickle_read_size": limit})
+        metadata = scanner.extract_metadata(str(pkl_file))
+
+        assert "extraction_error" in metadata
+        assert expected_error in metadata["extraction_error"]
+
     def test_pickle_safe_data_no_dangerous_opcodes(self, tmp_path: Path) -> None:
         """Ensure simple data structures don't trigger dangerous opcode detection."""
         from modelaudit.scanners.pickle_scanner import PickleScanner