Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,76 @@ def _looks_stateful_name(name: str) -> bool:
return any(token in lowered for token in ("last", "prev", "cache", "saved", "memo"))


def _is_input_float_call(expr: ast.AST | None, params: set[str]) -> bool:
return (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr == "float"
and isinstance(expr.func.value, ast.Name)
and expr.func.value.id in params
and not expr.args
and not expr.keywords
)


def _lambda_returns_input_float(expr: ast.AST | None) -> bool:
if not isinstance(expr, ast.Lambda):
return False
params = {arg.arg for arg in expr.args.args}
return _is_input_float_call(expr.body, params)


def _is_noop_context_call(expr: ast.AST | None) -> bool:
if not isinstance(expr, ast.Call):
return False
func = expr.func
if isinstance(func, ast.Name):
return func.id == "nullcontext"
return isinstance(func, ast.Attribute) and func.attr == "nullcontext"


def _input_float_return_from_body(body: list[ast.stmt], params: set[str]) -> bool:
aliases: set[str] = set()
saw_return = False

for stmt in body:
if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)):
continue
if (
isinstance(stmt, ast.Expr)
and isinstance(stmt.value, ast.Constant)
and isinstance(stmt.value.value, str)
):
continue
if isinstance(stmt, ast.With) and all(
_is_noop_context_call(item.context_expr)
for item in stmt.items
):
if _input_float_return_from_body(stmt.body, params):
saw_return = True
continue
return False
if (
isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and _is_input_float_call(stmt.value, params)
):
aliases.add(stmt.targets[0].id)
continue
if isinstance(stmt, ast.Return):
if _is_input_float_call(stmt.value, params):
saw_return = True
continue
if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases:
saw_return = True
continue
return False
return False

return saw_return


# ---------------------------------------------------------------------------
# Detectors
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1163,6 +1233,120 @@ def detect_trivial_probe(code: str | SubmissionFacts) -> list[dict]:
return matches


def detect_input_passthrough_output(code: str | SubmissionFacts) -> list[dict]:
    """Pattern: entrypoint returns the input tensor cast to float as fake output.

    Runs four passes over the module AST:

    1. Direct form — an entrypoint function whose body reduces to
       ``return <param>.float()`` (per ``_input_float_return_from_body``).
    2. Collect classes created dynamically via
       ``type(name, bases, {"__call__": lambda x: x.float()})`` and
       module-level instances of those classes.
    3. Collect class names that back a callable entrypoint: a ``ClassDef``
       with an entrypoint name, or an entrypoint-named variable assigned
       from a constructor-style call.
    4. Flag a callable-class ``__call__`` that is a passthrough, or an
       entrypoint that delegates straight to a dynamic callable instance.

    Returns a single-element list with a critical match dict on the first
    hit, or ``[]`` when nothing suspicious is found.
    """
    facts = ensure_submission_facts(code)
    tree = facts.ast_tree
    if tree is None:
        # Source did not parse — nothing to inspect.
        return []
    entrypoint_name = entrypoint_label(facts.entrypoint_name)
    # Class names that act as the callable entrypoint (pass 3 → pass 4).
    callable_entrypoint_classes: set[str] = set()
    # Names bound to type(...)-built classes whose __call__ is a passthrough.
    dynamic_callable_classes: set[str] = set()
    # Names bound to instances of those dynamic classes.
    dynamic_callable_instances: set[str] = set()

    # Pass 1: entrypoint function that directly returns <param>.float().
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if not is_entrypoint_name(node.name):
            continue

        params = {arg.arg for arg in node.args.args}
        if not params:
            # No inputs to pass through.
            continue
        if _input_float_return_from_body(node.body, params):
            return [{
                "pattern": "INPUT_PASSTHROUGH_OUTPUT",
                "severity": "critical",
                "evidence": f"{entrypoint_name} returns an input tensor cast to float without compute",
            }]

    # Pass 2: dynamic classes via type(...) and their instances.  Single
    # forward pass over top-level statements only — an instance assigned
    # before its class is defined is not recorded.
    for stmt in tree.body:
        if not isinstance(stmt, ast.Assign):
            continue
        target_names = {t.id for t in stmt.targets if isinstance(t, ast.Name)}
        if (
            isinstance(stmt.value, ast.Call)
            and isinstance(stmt.value.func, ast.Name)
            and stmt.value.func.id == "type"
            and len(stmt.value.args) >= 3
            and isinstance(stmt.value.args[2], ast.Dict)
        ):
            # type(name, bases, namespace): look for a "__call__" entry
            # whose value is a passthrough lambda.
            for key, value in zip(stmt.value.args[2].keys, stmt.value.args[2].values):
                if not (isinstance(key, ast.Constant) and key.value == "__call__"):
                    continue
                if _lambda_returns_input_float(value):
                    dynamic_callable_classes.update(target_names)
            continue
        if (
            isinstance(stmt.value, ast.Call)
            and isinstance(stmt.value.func, ast.Name)
            and stmt.value.func.id in dynamic_callable_classes
        ):
            dynamic_callable_instances.update(target_names)

    # Pass 3: which class names serve as the entrypoint?  Either the class
    # itself is entrypoint-named, or an entrypoint-named variable is assigned
    # from a call whose callee is a plain name (assumed to be a constructor —
    # non-class callees simply never match a ClassDef in pass 4).
    for stmt in tree.body:
        if isinstance(stmt, ast.ClassDef) and is_entrypoint_name(stmt.name):
            callable_entrypoint_classes.add(stmt.name)
            continue
        if not isinstance(stmt, ast.Assign):
            continue
        if not any(isinstance(t, ast.Name) and is_entrypoint_name(t.id) for t in stmt.targets):
            continue
        if isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name):
            callable_entrypoint_classes.add(stmt.value.func.id)

    # Pass 4a: a callable entrypoint class whose __call__ is a passthrough.
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef) or node.name not in callable_entrypoint_classes:
            continue
        for stmt in node.body:
            if not isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            if stmt.name != "__call__":
                continue
            params = {arg.arg for arg in stmt.args.args}
            if _input_float_return_from_body(stmt.body, params):
                return [{
                    "pattern": "INPUT_PASSTHROUGH_OUTPUT",
                    "severity": "critical",
                    "evidence": f"callable {entrypoint_name} returns an input tensor cast to float without compute",
                }]

    # Pass 4b: an entrypoint whose body (ignoring imports/docstrings/pass)
    # is a single return delegating one input to a dynamic callable instance.
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if not is_entrypoint_name(node.name):
            continue
        params = {arg.arg for arg in node.args.args}

        for stmt in node.body:
            if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)):
                continue
            if (
                isinstance(stmt, ast.Expr)
                and isinstance(stmt.value, ast.Constant)
                and isinstance(stmt.value.value, str)
            ):
                # Docstring / bare string literal.
                continue
            if not (
                isinstance(stmt, ast.Return)
                and isinstance(stmt.value, ast.Call)
                and isinstance(stmt.value.func, ast.Name)
                and stmt.value.func.id in dynamic_callable_instances
                and len(stmt.value.args) == 1
                and isinstance(stmt.value.args[0], ast.Name)
                and stmt.value.args[0].id in params
                and not stmt.value.keywords
            ):
                # Any other statement: this entrypoint does real work.
                break
            return [{
                "pattern": "INPUT_PASSTHROUGH_OUTPUT",
                "severity": "critical",
                "evidence": f"{entrypoint_name} delegates to a dynamic callable that returns input.float()",
            }]

    return []


def detect_torch_compile_cache(code: str | SubmissionFacts) -> list[dict]:
"""Pattern 11: torch.compile for pipeline graph caching."""
facts = ensure_submission_facts(code)
Expand Down Expand Up @@ -3511,6 +3695,10 @@ class RulePolicy:
"TRIVIAL_PROBE", "low_signal", "telemetry", TELEMETRY_ONLY, (),
(), "downgrade",
),
"INPUT_PASSTHROUGH_OUTPUT": RulePolicy(
"INPUT_PASSTHROUGH_OUTPUT", "fake_output", "hard", AUTO_FILTER, (),
(), "keep",
),
"TORCH_COMPILE_CACHE": RulePolicy(
"TORCH_COMPILE_CACHE", "performance_heuristic", "telemetry", TELEMETRY_ONLY, (),
(), "downgrade",
Expand Down Expand Up @@ -3767,6 +3955,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_cuda_graph_replay,
detect_silent_fallback,
detect_trivial_probe,
detect_input_passthrough_output,
detect_torch_compile_cache,
detect_cuda_graph_python,
detect_hardcoded_shapes,
Expand Down Expand Up @@ -3805,6 +3994,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("cuda_graph_replay", detect_cuda_graph_replay),
("silent_fallback", detect_silent_fallback),
("trivial_probe", detect_trivial_probe),
("input_passthrough_output", detect_input_passthrough_output),
("torch_compile_cache", detect_torch_compile_cache),
("cuda_graph_python", detect_cuda_graph_python),
("hardcoded_shapes", detect_hardcoded_shapes),
Expand Down