diff --git a/kernelguard.py b/kernelguard.py index f086847..0b5a726 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2202,6 +2202,65 @@ def _has_ver(expr: ast.AST | None) -> bool: continue if not is_entrypoint_name(node.name): continue + params = {arg.arg for arg in node.args.args} + + for loop in ast.walk(node): + if not isinstance(loop, ast.While): + continue + assigned_from_input = { + _ast_root_name(target) + for stmt in loop.body + if isinstance(stmt, ast.Assign) and (_expr_names(stmt.value) & params) + for target in stmt.targets + } - {None} + if not assigned_from_input: + continue + + guard_roots: set[str] = set() + for stmt in loop.body: + if not isinstance(stmt, ast.If): + continue + has_break = any(isinstance(inner, ast.Break) for inner in stmt.body) + if not has_break: + continue + for compare in ast.walk(stmt.test): + if not isinstance(compare, ast.Compare): + continue + operands = [compare.left] + list(compare.comparators) + for i, op in enumerate(compare.ops): + if not isinstance(op, (ast.IsNot, ast.NotEq)): + continue + left = operands[i] + right = operands[i + 1] + if isinstance(right, ast.Constant) and right.value is None: + root = _ast_root_name(left) + elif isinstance(left, ast.Constant) and left.value is None: + root = _ast_root_name(right) + else: + root = None + if root is not None: + guard_roots.add(root) + replay_roots = assigned_from_input & guard_roots + if not replay_roots or loop not in node.body: + continue + + later = node.body[node.body.index(loop) + 1:] + returned_after = { + _ast_root_name(stmt.value) + for stmt in later + if isinstance(stmt, ast.Return) and stmt.value is not None + } - {None} + replay_roots &= returned_after + if replay_roots: + root = next(iter(replay_roots)) + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} returns first-call output from " + f"state slot {root} initialized inside replay loop" + ), + }] signature_features: dict[str, set[str]] = defaultdict(set) saved_state_features: dict[str, set[str]] = defaultdict(set)