Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2053,6 +2053,73 @@ def _is_data_ptr_call(expr: ast.AST | None) -> bool:
and expr.func.attr == "data_ptr"
)

def _data_ptr_receiver(expr: ast.AST | None) -> Optional[str]:
if not _is_data_ptr_call(expr):
return None
return _ast_root_name(expr.func.value)

for class_node in ast.walk(tree):
if not isinstance(class_node, ast.ClassDef) or not is_entrypoint_name(class_node.name):
continue
methods = [
stmt for stmt in class_node.body
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) and stmt.name == "__call__"
]
for method in methods:
params = {arg.arg for arg in method.args.args}
params.discard("self")
state_root = None
for child in ast.walk(method):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue
returned_roots = {
_ast_root_name(stmt.value)
for stmt in child.body
if isinstance(stmt, ast.Return) and stmt.value is not None
} - {None}
if not returned_roots:
continue

compared_state_roots: set[str] = set()
for compare in ast.walk(child.test):
if not isinstance(compare, ast.Compare):
continue
operands = [compare.left] + list(compare.comparators)
for i, op in enumerate(compare.ops):
if not isinstance(op, ast.Eq):
continue
left_name = _data_ptr_receiver(operands[i])
right_name = _data_ptr_receiver(operands[i + 1])
if left_name in params and right_name not in params:
compared_state_roots.add(right_name)
elif right_name in params and left_name not in params:
compared_state_roots.add(left_name)
replay_roots = returned_roots & compared_state_roots
if replay_roots:
state_root = next(iter(replay_roots))
break
if state_root is None:
continue

stores_state_tuple = any(
isinstance(stmt, ast.Assign)
and any(_ast_root_name(target) == state_root for target in stmt.targets)
and isinstance(stmt.value, (ast.Tuple, ast.List))
and any(_expr_names(elt) & params for elt in stmt.value.elts)
for stmt in ast.walk(method)
)
if stores_state_tuple:
return [{
"pattern": "POINTER_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name}.__call__ replays tuple-stored output "
"when saved tensor data_ptr matches input"
),
}]

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
Expand Down