From e95ac6275715662ec164d1e98f311d75bfa7105c Mon Sep 17 00:00:00 2001 From: Prasanna Date: Sat, 2 May 2026 22:49:42 +0530 Subject: [PATCH] Detect helper guard state replay --- kernelguard.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/kernelguard.py b/kernelguard.py index f086847..f1f43fc 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2202,6 +2202,73 @@ def _has_ver(expr: ast.AST | None) -> bool: continue if not is_entrypoint_name(node.name): continue + params = {arg.arg for arg in node.args.args} + + helper_guard_roots: dict[str, str] = {} + for helper in node.body: + if not isinstance(helper, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + for stmt in helper.body: + if not isinstance(stmt, ast.Return): + continue + for compare in ast.walk(stmt.value): + if not isinstance(compare, ast.Compare): + continue + operands = [compare.left] + list(compare.comparators) + for i, op in enumerate(compare.ops): + if not isinstance(op, (ast.Is, ast.Eq)): + continue + left = operands[i] + right = operands[i + 1] + if isinstance(right, ast.Constant) and right.value is None: + root = _ast_root_name(left) + elif isinstance(left, ast.Constant) and left.value is None: + root = _ast_root_name(right) + else: + root = None + if root is not None: + helper_guard_roots[helper.name] = root + + for child in ast.walk(node): + if not isinstance(child, ast.If): + continue + called_helpers = { + call.func.id + for call in ast.walk(child.test) + if isinstance(call, ast.Call) + and isinstance(call.func, ast.Name) + and call.func.id in helper_guard_roots + } + if not called_helpers: + continue + guarded_roots = {helper_guard_roots[name] for name in called_helpers} + assigned_from_input = { + _ast_root_name(target) + for stmt in child.body + if isinstance(stmt, ast.Assign) and (_expr_names(stmt.value) & params) + for target in stmt.targets + } - {None} + if not (guarded_roots & assigned_from_input): + continue + if child not in node.body: + continue + later = node.body[node.body.index(child) + 1:] + returned_after = { + _ast_root_name(stmt.value) + for stmt in later + if isinstance(stmt, ast.Return) and stmt.value is not None + } - {None} + replay_roots = guarded_roots & assigned_from_input & returned_after + if replay_roots: + root = next(iter(replay_roots)) + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} returns first-call output from " + f"state slot {root} behind input-independent helper guard" + ), + }] signature_features: dict[str, set[str]] = defaultdict(set) saved_state_features: dict[str, set[str]] = defaultdict(set)