diff --git a/kernelguard.py b/kernelguard.py index f086847..6ae35ea 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2168,6 +2168,91 @@ def detect_last_call_replay(code: str | SubmissionFacts) -> list[dict]: return [] entrypoint_name = entrypoint_label(facts.entrypoint_name) + replay_generators: set[str] = set() + for gen in ast.walk(tree): + if not isinstance(gen, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not any(isinstance(child, (ast.Yield, ast.YieldFrom)) for child in ast.walk(gen)): + continue + + yielded_input_names = { + target.id + for stmt in ast.walk(gen) + if isinstance(stmt, ast.Assign) + for target in stmt.targets + if isinstance(target, ast.Name) and isinstance(stmt.value, ast.Yield) + } + if not yielded_input_names: + continue + + for child in ast.walk(gen): + if not isinstance(child, ast.If): + continue + identity_pairs: list[tuple[str, str]] = [] + for compare in ast.walk(child.test): + if not isinstance(compare, ast.Compare): + continue + operands = [compare.left] + list(compare.comparators) + for i, op in enumerate(compare.ops): + if not isinstance(op, ast.Is): + continue + left = _ast_root_name(operands[i]) + right = _ast_root_name(operands[i + 1]) + if left in yielded_input_names and right: + identity_pairs.append((left, right)) + elif right in yielded_input_names and left: + identity_pairs.append((right, left)) + if not identity_pairs: + continue + + yielded_fast_names = { + _ast_root_name(expr.value) + for expr in ast.walk(ast.Module(body=child.body, type_ignores=[])) + if isinstance(expr, ast.Yield) + } - {None} + assigned_after_fast: set[str] = set() + for stmt in child.orelse: + if not isinstance(stmt, ast.Assign): + continue + targets = stmt.targets + values = [stmt.value] * len(targets) + if len(targets) == 1 and isinstance(targets[0], ast.Tuple) and isinstance(stmt.value, ast.Tuple): + targets = list(targets[0].elts) + values = list(stmt.value.elts) + for target, value in zip(targets, values): + if _expr_names(value) & yielded_input_names: + name = _ast_root_name(target) + if name: + assigned_after_fast.add(name) + + if yielded_fast_names & assigned_after_fast: + replay_generators.add(gen.name) + break + + if replay_generators: + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + params = {arg.arg for arg in node.args.args} + sends_input = any( + isinstance(call, ast.Call) + and isinstance(call.func, ast.Attribute) + and call.func.attr == "send" + and any(_expr_names(arg) & params for arg in call.args) + for call in ast.walk(node) + ) + if sends_input: + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} sends input into generator with " + "identity-guarded output replay" + ), + }] + # Use pre-computed indices from build_submission_facts none_inited = facts._none_inited data_ptr_helpers = facts._data_ptr_helpers