Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2202,6 +2202,56 @@ def _has_ver(expr: ast.AST | None) -> bool:
continue
if not is_entrypoint_name(node.name):
continue
decorator_names = {
deco.id
for deco in node.decorator_list
if isinstance(deco, ast.Name)
}
if decorator_names:
for decorator in ast.walk(tree):
if not isinstance(decorator, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if decorator.name not in decorator_names:
continue
outer_params = {arg.arg for arg in decorator.args.args}
for wrapper in ast.walk(decorator):
if wrapper is decorator or not isinstance(wrapper, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
wrapper_params = {arg.arg for arg in wrapper.args.args}
cached_returns = {
_ast_root_name(stmt.value)
for child in ast.walk(wrapper)
if isinstance(child, ast.If)
for stmt in child.body
if isinstance(stmt, ast.Return) and stmt.value is not None
} - {None}
if not cached_returns:
continue
assigns_cache_from_call = {
_ast_root_name(target)
for stmt in ast.walk(wrapper)
if isinstance(stmt, ast.Assign)
and isinstance(stmt.value, ast.Call)
and _ast_root_name(stmt.value.func) in outer_params
and any(_expr_names(arg) & wrapper_params for arg in stmt.value.args)
for target in stmt.targets
} - {None}
returned_after_store = {
_ast_root_name(stmt.value)
for stmt in wrapper.body
if isinstance(stmt, ast.Return) and stmt.value is not None
} - {None}
replay_roots = cached_returns & assigns_cache_from_call & returned_after_store
if replay_roots:
root = next(iter(replay_roots))
return [{
"pattern": "LAST_CALL_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} is decorated by wrapper "
f"that replays cached output slot {root}"
),
}]

signature_features: dict[str, set[str]] = defaultdict(set)
saved_state_features: dict[str, set[str]] = defaultdict(set)
Expand Down