Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2202,6 +2202,73 @@ def _has_ver(expr: ast.AST | None) -> bool:
continue
if not is_entrypoint_name(node.name):
continue
params = {arg.arg for arg in node.args.args}

default_state = {
arg.arg
for arg, default in zip(node.args.args[-len(node.args.defaults):], node.args.defaults)
if isinstance(default, (ast.List, ast.Tuple)) and len(default.elts) >= 2
}

for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue

replay_state = None
for compare in ast.walk(child.test):
if not isinstance(compare, ast.Compare):
continue
operands = [compare.left] + list(compare.comparators)
for i, op in enumerate(compare.ops):
if not isinstance(op, ast.Is):
continue
left = operands[i]
right = operands[i + 1]
left_root = _ast_root_name(left)
right_root = _ast_root_name(right)
if left_root in default_state and right_root in params:
replay_state = left_root
elif right_root in default_state and left_root in params:
replay_state = right_root
if replay_state is None:
continue

returns_default_slot = any(
isinstance(stmt, ast.Return)
and _ast_root_name(stmt.value) == replay_state
for stmt in child.body
)
if not returns_default_slot:
continue

writes_input_slot = False
writes_output_slot = False
for stmt in ast.walk(node):
if not isinstance(stmt, ast.Assign):
continue
targets = stmt.targets
values = [stmt.value] * len(targets)
if len(targets) == 1 and isinstance(targets[0], ast.Tuple) and isinstance(stmt.value, ast.Tuple):
targets = list(targets[0].elts)
values = list(stmt.value.elts)
for target, value in zip(targets, values):
if _ast_root_name(target) != replay_state:
continue
if isinstance(value, ast.Name) and value.id in params:
writes_input_slot = True
elif _expr_names(value) & params:
writes_output_slot = True
if writes_input_slot and writes_output_slot:
return [{
"pattern": "LAST_CALL_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} replays from mutable default state "
"after input identity match"
),
}]

signature_features: dict[str, set[str]] = defaultdict(set)
saved_state_features: dict[str, set[str]] = defaultdict(set)
Expand Down