Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2168,6 +2168,91 @@ def detect_last_call_replay(code: str | SubmissionFacts) -> list[dict]:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)

replay_generators: set[str] = set()
for gen in ast.walk(tree):
if not isinstance(gen, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not any(isinstance(child, (ast.Yield, ast.YieldFrom)) for child in ast.walk(gen)):
continue

yielded_input_names = {
target.id
for stmt in ast.walk(gen)
if isinstance(stmt, ast.Assign)
for target in stmt.targets
if isinstance(target, ast.Name) and isinstance(stmt.value, ast.Yield)
}
if not yielded_input_names:
continue

for child in ast.walk(gen):
if not isinstance(child, ast.If):
continue
identity_pairs: list[tuple[str, str]] = []
for compare in ast.walk(child.test):
if not isinstance(compare, ast.Compare):
continue
operands = [compare.left] + list(compare.comparators)
for i, op in enumerate(compare.ops):
if not isinstance(op, ast.Is):
continue
left = _ast_root_name(operands[i])
right = _ast_root_name(operands[i + 1])
if left in yielded_input_names and right:
identity_pairs.append((left, right))
elif right in yielded_input_names and left:
identity_pairs.append((right, left))
if not identity_pairs:
continue

yielded_fast_names = {
_ast_root_name(expr.value)
for expr in ast.walk(ast.Module(body=child.body, type_ignores=[]))
if isinstance(expr, ast.Yield)
} - {None}
assigned_after_fast: set[str] = set()
for stmt in child.orelse:
if not isinstance(stmt, ast.Assign):
continue
targets = stmt.targets
values = [stmt.value] * len(targets)
if len(targets) == 1 and isinstance(targets[0], ast.Tuple) and isinstance(stmt.value, ast.Tuple):
targets = list(targets[0].elts)
values = list(stmt.value.elts)
for target, value in zip(targets, values):
if _expr_names(value) & yielded_input_names:
name = _ast_root_name(target)
if name:
assigned_after_fast.add(name)

if yielded_fast_names & assigned_after_fast:
replay_generators.add(gen.name)
break

if replay_generators:
for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue
params = {arg.arg for arg in node.args.args}
sends_input = any(
isinstance(call, ast.Call)
and isinstance(call.func, ast.Attribute)
and call.func.attr == "send"
and any(_expr_names(arg) & params for arg in call.args)
for call in ast.walk(node)
)
if sends_input:
return [{
"pattern": "LAST_CALL_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} sends input into generator with "
"identity-guarded output replay"
),
}]

# Use pre-computed indices from build_submission_facts
none_inited = facts._none_inited
data_ptr_helpers = facts._data_ptr_helpers
Expand Down