Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,69 @@ def _looks_stateful_name(name: str) -> bool:
return any(token in lowered for token in ("last", "prev", "cache", "saved", "memo"))


def _is_input_float_call(expr: ast.AST | None, params: set[str]) -> bool:
return (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr == "float"
and isinstance(expr.func.value, ast.Name)
and expr.func.value.id in params
and not expr.args
and not expr.keywords
)


def _is_noop_context_call(expr: ast.AST | None) -> bool:
if not isinstance(expr, ast.Call):
return False
func = expr.func
if isinstance(func, ast.Name):
return func.id == "nullcontext"
return isinstance(func, ast.Attribute) and func.attr == "nullcontext"


def _input_float_return_from_body(
    body: list[ast.stmt],
    params: set[str],
    aliases: set[str] | None = None,
) -> bool:
    """Return True when *body* does nothing but return an input cast via ``.float()``.

    A body qualifies when every statement is one of: an import, ``pass``, a
    docstring, a no-op ``with nullcontext():`` block, an alias assignment of
    the form ``y = <param>.float()``, or a ``return`` of either such a call or
    a previously recorded alias.  Any other statement means real work happens
    and the scan bails out with False.

    ``aliases`` is shared with recursive calls into ``with`` blocks so an
    alias bound outside the block is still recognised inside it (the previous
    version re-created the set per recursion and missed that case).  Callers
    normally omit the argument.
    """
    if aliases is None:
        aliases = set()
    saw_return = False

    for stmt in body:
        # Imports and `pass` are harmless filler.
        if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)):
            continue
        # A bare string expression (docstring) is also harmless.
        if (
            isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str)
        ):
            continue
        # `with nullcontext():` adds no behaviour; scan its body in place,
        # sharing the alias set so bindings cross the block boundary.
        if isinstance(stmt, ast.With) and all(
            _is_noop_context_call(item.context_expr)
            for item in stmt.items
        ):
            if _input_float_return_from_body(stmt.body, params, aliases):
                saw_return = True
                continue
            return False
        # `y = x.float()` records an alias usable by a later `return y`.
        if (
            isinstance(stmt, ast.Assign)
            and len(stmt.targets) == 1
            and isinstance(stmt.targets[0], ast.Name)
            and _is_input_float_call(stmt.value, params)
        ):
            aliases.add(stmt.targets[0].id)
            continue
        if isinstance(stmt, ast.Return):
            if _is_input_float_call(stmt.value, params):
                saw_return = True
                continue
            if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases:
                saw_return = True
                continue
            # Returning anything else means the body computes something.
            return False
        # Any unrecognised statement disqualifies the body.
        return False

    return saw_return


# ---------------------------------------------------------------------------
# Detectors
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1163,6 +1226,33 @@ def detect_trivial_probe(code: str | SubmissionFacts) -> list[dict]:
return matches


def detect_input_passthrough_output(code: str | SubmissionFacts) -> list[dict]:
    """Pattern: entrypoint returns the input tensor cast to float as fake output.

    Walks the AST for any (async) function whose name marks it as the
    benchmark entrypoint and flags it when its body does nothing but return
    one of its parameters cast via ``.float()`` (see
    ``_input_float_return_from_body``).

    Returns a single critical match dict on the first hit, else an empty list.
    """
    facts = ensure_submission_facts(code)
    tree = facts.ast_tree
    if tree is None:  # unparsable source — nothing to inspect
        return []
    entrypoint_name = entrypoint_label(facts.entrypoint_name)

    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if not is_entrypoint_name(node.name):
            continue

        # Collect every named parameter.  `node.args.args` alone excludes
        # positional-only and keyword-only parameters, so entrypoints using
        # those signatures would otherwise never be flagged.
        params = {
            arg.arg
            for arg in (*node.args.posonlyargs, *node.args.args, *node.args.kwonlyargs)
        }
        if not params:
            continue
        if _input_float_return_from_body(node.body, params):
            return [{
                "pattern": "INPUT_PASSTHROUGH_OUTPUT",
                "severity": "critical",
                "evidence": f"{entrypoint_name} returns an input tensor cast to float without compute",
            }]

    return []


def detect_torch_compile_cache(code: str | SubmissionFacts) -> list[dict]:
"""Pattern 11: torch.compile for pipeline graph caching."""
facts = ensure_submission_facts(code)
Expand Down Expand Up @@ -3511,6 +3601,10 @@ class RulePolicy:
"TRIVIAL_PROBE", "low_signal", "telemetry", TELEMETRY_ONLY, (),
(), "downgrade",
),
"INPUT_PASSTHROUGH_OUTPUT": RulePolicy(
"INPUT_PASSTHROUGH_OUTPUT", "fake_output", "hard", AUTO_FILTER, (),
(), "keep",
),
"TORCH_COMPILE_CACHE": RulePolicy(
"TORCH_COMPILE_CACHE", "performance_heuristic", "telemetry", TELEMETRY_ONLY, (),
(), "downgrade",
Expand Down Expand Up @@ -3767,6 +3861,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_cuda_graph_replay,
detect_silent_fallback,
detect_trivial_probe,
detect_input_passthrough_output,
detect_torch_compile_cache,
detect_cuda_graph_python,
detect_hardcoded_shapes,
Expand Down Expand Up @@ -3805,6 +3900,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("cuda_graph_replay", detect_cuda_graph_replay),
("silent_fallback", detect_silent_fallback),
("trivial_probe", detect_trivial_probe),
("input_passthrough_output", detect_input_passthrough_output),
("torch_compile_cache", detect_torch_compile_cache),
("cuda_graph_python", detect_cuda_graph_python),
("hardcoded_shapes", detect_hardcoded_shapes),
Expand Down