diff --git a/kernelguard.py b/kernelguard.py index f086847..7099732 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2053,6 +2053,73 @@ def _is_data_ptr_call(expr: ast.AST | None) -> bool: and expr.func.attr == "data_ptr" ) + def _data_ptr_receiver(expr: ast.AST | None) -> Optional[str]: + if not _is_data_ptr_call(expr): + return None + return _ast_root_name(expr.func.value) + + for class_node in ast.walk(tree): + if not isinstance(class_node, ast.ClassDef) or not is_entrypoint_name(class_node.name): + continue + methods = [ + stmt for stmt in class_node.body + if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) and stmt.name == "__call__" + ] + for method in methods: + params = {arg.arg for arg in method.args.args} + params.discard("self") + state_root = None + for child in ast.walk(method): + if not isinstance(child, ast.If): + continue + if _body_has_calls(child.body): + continue + returned_roots = { + _ast_root_name(stmt.value) + for stmt in child.body + if isinstance(stmt, ast.Return) and stmt.value is not None + } - {None} + if not returned_roots: + continue + + compared_state_roots: set[str] = set() + for compare in ast.walk(child.test): + if not isinstance(compare, ast.Compare): + continue + operands = [compare.left] + list(compare.comparators) + for i, op in enumerate(compare.ops): + if not isinstance(op, ast.Eq): + continue + left_name = _data_ptr_receiver(operands[i]) + right_name = _data_ptr_receiver(operands[i + 1]) + if left_name in params and right_name not in params: + compared_state_roots.add(right_name) + elif right_name in params and left_name not in params: + compared_state_roots.add(left_name) + replay_roots = returned_roots & compared_state_roots + if replay_roots: + state_root = next(iter(replay_roots)) + break + if state_root is None: + continue + + stores_state_tuple = any( + isinstance(stmt, ast.Assign) + and any(_ast_root_name(target) == state_root for target in stmt.targets) + and isinstance(stmt.value, (ast.Tuple, ast.List)) + and any(_expr_names(elt) & params for elt in stmt.value.elts) + for stmt in ast.walk(method) + ) + if stores_state_tuple: + return [{ + "pattern": "POINTER_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name}.__call__ replays tuple-stored output " + "when saved tensor data_ptr matches input" + ), + }] + for node in ast.walk(tree): if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): continue