diff --git a/kernelguard.py b/kernelguard.py index f086847..cb9320d 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2202,6 +2202,57 @@ def _has_ver(expr: ast.AST | None) -> bool: continue if not is_entrypoint_name(node.name): continue + params = {arg.arg for arg in node.args.args} + + for child in ast.walk(node): + if not isinstance(child, ast.If): + continue + alias_to_state: dict[str, str] = {} + for named in ast.walk(child.test): + if not isinstance(named, ast.NamedExpr): + continue + if not isinstance(named.target, ast.Name): + continue + state_name = _ast_root_name(named.value) + if state_name is not None: + alias_to_state[named.target.id] = state_name + if not alias_to_state: + continue + + returned_aliases = { + stmt.value.id + for stmt in child.orelse + if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name) + } + replay_states = { + alias_to_state[name] + for name in returned_aliases + if name in alias_to_state + } + if not replay_states: + continue + + assigned_from_input = { + _ast_root_name(target) + for stmt in child.body + if isinstance(stmt, ast.Assign) and (_expr_names(stmt.value) & params) + for target in stmt.targets + } - {None} + returned_state = { + _ast_root_name(stmt.value) + for stmt in child.body + if isinstance(stmt, ast.Return) and stmt.value is not None + } - {None} + if replay_states & assigned_from_input & returned_state: + state_name = next(iter(replay_states & assigned_from_input & returned_state)) + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} replays first-call state through " + f"assignment-expression alias of {state_name}" + ), + }] signature_features: dict[str, set[str]] = defaultdict(set) saved_state_features: dict[str, set[str]] = defaultdict(set)