Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
353 changes: 353 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,91 @@ def _looks_stateful_name(name: str) -> bool:
return any(token in lowered for token in ("last", "prev", "cache", "saved", "memo"))


def _is_input_float_call(expr: ast.AST | None, params: set[str]) -> bool:
return (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr == "float"
and isinstance(expr.func.value, ast.Name)
and expr.func.value.id in params
and not expr.args
and not expr.keywords
)


def _lambda_returns_input_float(expr: ast.AST | None) -> bool:
if not isinstance(expr, ast.Lambda):
return False
params = {arg.arg for arg in expr.args.args}
return _is_input_float_call(expr.body, params)


def _input_attr_float_storage(expr: ast.AST | None, owners: set[str]) -> Optional[str]:
if not (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr == "float"
and not expr.args
and not expr.keywords
and isinstance(expr.func.value, ast.Attribute)
and isinstance(expr.func.value.value, ast.Name)
and expr.func.value.value.id in owners
):
return None
return expr.func.value.attr


def _is_noop_context_call(expr: ast.AST | None) -> bool:
if not isinstance(expr, ast.Call):
return False
func = expr.func
if isinstance(func, ast.Name):
return func.id == "nullcontext"
return isinstance(func, ast.Attribute) and func.attr == "nullcontext"


def _input_float_return_from_body(body: list[ast.stmt], params: set[str]) -> bool:
aliases: set[str] = set()
saw_return = False

for stmt in body:
if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)):
continue
if (
isinstance(stmt, ast.Expr)
and isinstance(stmt.value, ast.Constant)
and isinstance(stmt.value.value, str)
):
continue
if isinstance(stmt, ast.With) and all(
_is_noop_context_call(item.context_expr)
for item in stmt.items
):
if _input_float_return_from_body(stmt.body, params):
saw_return = True
continue
return False
if (
isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and _is_input_float_call(stmt.value, params)
):
aliases.add(stmt.targets[0].id)
continue
if isinstance(stmt, ast.Return):
if _is_input_float_call(stmt.value, params):
saw_return = True
continue
if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases:
saw_return = True
continue
return False
return False

return saw_return


# ---------------------------------------------------------------------------
# Detectors
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1163,6 +1248,268 @@ def detect_trivial_probe(code: str | SubmissionFacts) -> list[dict]:
return matches


def detect_input_passthrough_output(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: entrypoint returns the input tensor cast to float as fake output."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)
callable_entrypoint_classes: set[str] = set()
dynamic_callable_classes: set[str] = set()
dynamic_callable_instances: set[str] = set()
subclass_hook_bases: set[str] = set()
subclass_entrypoint_classes: set[str] = set()
descriptor_classes: dict[str, str] = {}
wrapper_descriptor_attrs: dict[tuple[str, str], str] = {}

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue

params = {arg.arg for arg in node.args.args}
if not params:
continue
if _input_float_return_from_body(node.body, params):
return [{
"pattern": "INPUT_PASSTHROUGH_OUTPUT",
"severity": "critical",
"evidence": f"{entrypoint_name} returns an input tensor cast to float without compute",
}]

for stmt in tree.body:
if not isinstance(stmt, ast.Assign):
continue
target_names = {t.id for t in stmt.targets if isinstance(t, ast.Name)}
if (
isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and stmt.value.func.id == "type"
and len(stmt.value.args) >= 3
and isinstance(stmt.value.args[2], ast.Dict)
):
for key, value in zip(stmt.value.args[2].keys, stmt.value.args[2].values):
if not (isinstance(key, ast.Constant) and key.value == "__call__"):
continue
if _lambda_returns_input_float(value):
dynamic_callable_classes.update(target_names)
continue
if (
isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and stmt.value.func.id in dynamic_callable_classes
):
dynamic_callable_instances.update(target_names)

for stmt in tree.body:
if not isinstance(stmt, ast.ClassDef):
continue
for child in stmt.body:
if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if child.name != "__init_subclass__":
continue
params = {arg.arg for arg in child.args.args}
for assign in ast.walk(child):
if not isinstance(assign, ast.Assign):
continue
if not _lambda_returns_input_float(assign.value):
continue
for target in assign.targets:
if not isinstance(target, ast.Attribute):
continue
if not is_entrypoint_name(target.attr):
continue
if isinstance(target.value, ast.Name) and target.value.id in params:
subclass_hook_bases.add(stmt.name)

for stmt in tree.body:
if not isinstance(stmt, ast.ClassDef):
continue
if any(
isinstance(base, ast.Name) and base.id in subclass_hook_bases
for base in stmt.bases
):
subclass_entrypoint_classes.add(stmt.name)

for stmt in tree.body:
if not isinstance(stmt, ast.Assign):
continue
if not any(isinstance(t, ast.Name) and is_entrypoint_name(t.id) for t in stmt.targets):
continue
if not (
isinstance(stmt.value, ast.Attribute)
and is_entrypoint_name(stmt.value.attr)
and isinstance(stmt.value.value, ast.Name)
and stmt.value.value.id in subclass_entrypoint_classes
):
continue
return [{
"pattern": "INPUT_PASSTHROUGH_OUTPUT",
"severity": "critical",
"evidence": f"{entrypoint_name} is installed by __init_subclass__ as input.float()",
}]

for stmt in tree.body:
if not isinstance(stmt, ast.ClassDef):
continue
for child in stmt.body:
if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if child.name != "__get__" or len(child.args.args) < 2:
continue
owner_name = child.args.args[1].arg
returns = [s for s in child.body if isinstance(s, ast.Return)]
if len(returns) != 1:
continue
storage_attr = _input_attr_float_storage(returns[0].value, {owner_name})
if storage_attr is not None:
descriptor_classes[stmt.name] = storage_attr

for stmt in tree.body:
if not isinstance(stmt, ast.ClassDef):
continue

descriptor_attrs: dict[str, str] = {}
for child in stmt.body:
if not isinstance(child, ast.Assign):
continue
if not (
isinstance(child.value, ast.Call)
and isinstance(child.value.func, ast.Name)
and child.value.func.id in descriptor_classes
):
continue
for target in child.targets:
if isinstance(target, ast.Name):
descriptor_attrs[target.id] = descriptor_classes[child.value.func.id]

if not descriptor_attrs:
continue

init_stores: set[str] = set()
for child in stmt.body:
if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if child.name != "__init__" or not child.args.args:
continue
self_name = child.args.args[0].arg
value_params = {arg.arg for arg in child.args.args[1:]}
for assign in ast.walk(child):
if not isinstance(assign, ast.Assign):
continue
if not (isinstance(assign.value, ast.Name) and assign.value.id in value_params):
continue
for target in assign.targets:
if (
isinstance(target, ast.Attribute)
and isinstance(target.value, ast.Name)
and target.value.id == self_name
):
init_stores.add(target.attr)

for public_attr, storage_attr in descriptor_attrs.items():
if storage_attr in init_stores:
wrapper_descriptor_attrs[(stmt.name, public_attr)] = storage_attr

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue
params = {arg.arg for arg in node.args.args}
for stmt in node.body:
if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)):
continue
if (
isinstance(stmt, ast.Expr)
and isinstance(stmt.value, ast.Constant)
and isinstance(stmt.value.value, str)
):
continue
if not (
isinstance(stmt, ast.Return)
and isinstance(stmt.value, ast.Attribute)
and isinstance(stmt.value.value, ast.Call)
and isinstance(stmt.value.value.func, ast.Name)
and (stmt.value.value.func.id, stmt.value.attr) in wrapper_descriptor_attrs
and len(stmt.value.value.args) == 1
and isinstance(stmt.value.value.args[0], ast.Name)
and stmt.value.value.args[0].id in params
and not stmt.value.value.keywords
):
break
return [{
"pattern": "INPUT_PASSTHROUGH_OUTPUT",
"severity": "critical",
"evidence": f"{entrypoint_name} returns a descriptor-backed input.float() value",
}]

for stmt in tree.body:
if isinstance(stmt, ast.ClassDef) and is_entrypoint_name(stmt.name):
callable_entrypoint_classes.add(stmt.name)
continue
if not isinstance(stmt, ast.Assign):
continue
if not any(isinstance(t, ast.Name) and is_entrypoint_name(t.id) for t in stmt.targets):
continue
if isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name):
callable_entrypoint_classes.add(stmt.value.func.id)

for node in ast.walk(tree):
if not isinstance(node, ast.ClassDef) or node.name not in callable_entrypoint_classes:
continue
for stmt in node.body:
if not isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if stmt.name != "__call__":
continue
params = {arg.arg for arg in stmt.args.args}
if _input_float_return_from_body(stmt.body, params):
return [{
"pattern": "INPUT_PASSTHROUGH_OUTPUT",
"severity": "critical",
"evidence": f"callable {entrypoint_name} returns an input tensor cast to float without compute",
}]

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue
params = {arg.arg for arg in node.args.args}

for stmt in node.body:
if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)):
continue
if (
isinstance(stmt, ast.Expr)
and isinstance(stmt.value, ast.Constant)
and isinstance(stmt.value.value, str)
):
continue
if not (
isinstance(stmt, ast.Return)
and isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and stmt.value.func.id in dynamic_callable_instances
and len(stmt.value.args) == 1
and isinstance(stmt.value.args[0], ast.Name)
and stmt.value.args[0].id in params
and not stmt.value.keywords
):
break
return [{
"pattern": "INPUT_PASSTHROUGH_OUTPUT",
"severity": "critical",
"evidence": f"{entrypoint_name} delegates to a dynamic callable that returns input.float()",
}]

return []


def detect_torch_compile_cache(code: str | SubmissionFacts) -> list[dict]:
"""Pattern 11: torch.compile for pipeline graph caching."""
facts = ensure_submission_facts(code)
Expand Down Expand Up @@ -3511,6 +3858,10 @@ class RulePolicy:
"TRIVIAL_PROBE", "low_signal", "telemetry", TELEMETRY_ONLY, (),
(), "downgrade",
),
"INPUT_PASSTHROUGH_OUTPUT": RulePolicy(
"INPUT_PASSTHROUGH_OUTPUT", "fake_output", "hard", AUTO_FILTER, (),
(), "keep",
),
"TORCH_COMPILE_CACHE": RulePolicy(
"TORCH_COMPILE_CACHE", "performance_heuristic", "telemetry", TELEMETRY_ONLY, (),
(), "downgrade",
Expand Down Expand Up @@ -3767,6 +4118,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_cuda_graph_replay,
detect_silent_fallback,
detect_trivial_probe,
detect_input_passthrough_output,
detect_torch_compile_cache,
detect_cuda_graph_python,
detect_hardcoded_shapes,
Expand Down Expand Up @@ -3805,6 +4157,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("cuda_graph_replay", detect_cuda_graph_replay),
("silent_fallback", detect_silent_fallback),
("trivial_probe", detect_trivial_probe),
("input_passthrough_output", detect_input_passthrough_output),
("torch_compile_cache", detect_torch_compile_cache),
("cuda_graph_python", detect_cuda_graph_python),
("hardcoded_shapes", detect_hardcoded_shapes),
Expand Down