diff --git a/kernelguard.py b/kernelguard.py index f086847..1ac48bc 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -553,6 +553,91 @@ def _looks_stateful_name(name: str) -> bool: return any(token in lowered for token in ("last", "prev", "cache", "saved", "memo")) +def _is_input_float_call(expr: ast.AST | None, params: set[str]) -> bool: + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr == "float" + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id in params + and not expr.args + and not expr.keywords + ) + + +def _lambda_returns_input_float(expr: ast.AST | None) -> bool: + if not isinstance(expr, ast.Lambda): + return False + params = {arg.arg for arg in expr.args.args} + return _is_input_float_call(expr.body, params) + + +def _input_attr_float_storage(expr: ast.AST | None, owners: set[str]) -> Optional[str]: + if not ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr == "float" + and not expr.args + and not expr.keywords + and isinstance(expr.func.value, ast.Attribute) + and isinstance(expr.func.value.value, ast.Name) + and expr.func.value.value.id in owners + ): + return None + return expr.func.value.attr + + +def _is_noop_context_call(expr: ast.AST | None) -> bool: + if not isinstance(expr, ast.Call): + return False + func = expr.func + if isinstance(func, ast.Name): + return func.id == "nullcontext" + return isinstance(func, ast.Attribute) and func.attr == "nullcontext" + + +def _input_float_return_from_body(body: list[ast.stmt], params: set[str]) -> bool: + aliases: set[str] = set() + saw_return = False + + for stmt in body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if isinstance(stmt, ast.With) and all( + _is_noop_context_call(item.context_expr) + for item in stmt.items + ): + if _input_float_return_from_body(stmt.body, params): + saw_return = True + continue + return False + if ( + isinstance(stmt, ast.Assign) + and len(stmt.targets) == 1 + and isinstance(stmt.targets[0], ast.Name) + and _is_input_float_call(stmt.value, params) + ): + aliases.add(stmt.targets[0].id) + continue + if isinstance(stmt, ast.Return): + if _is_input_float_call(stmt.value, params): + saw_return = True + continue + if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases: + saw_return = True + continue + return False + return False + + return saw_return + + # --------------------------------------------------------------------------- # Detectors # --------------------------------------------------------------------------- @@ -1163,6 +1248,268 @@ def detect_trivial_probe(code: str | SubmissionFacts) -> list[dict]: return matches +def detect_input_passthrough_output(code: str | SubmissionFacts) -> list[dict]: + """Pattern: entrypoint returns the input tensor cast to float as fake output.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + callable_entrypoint_classes: set[str] = set() + dynamic_callable_classes: set[str] = set() + dynamic_callable_instances: set[str] = set() + subclass_hook_bases: set[str] = set() + subclass_entrypoint_classes: set[str] = set() + descriptor_classes: dict[str, str] = {} + wrapper_descriptor_attrs: dict[tuple[str, str], str] = {} + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + + params = {arg.arg for arg in node.args.args} + if not params: + continue + if _input_float_return_from_body(node.body, params): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns an input tensor cast to float without compute", + }] + + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + target_names = {t.id for t in stmt.targets if isinstance(t, ast.Name)} + if ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id == "type" + and len(stmt.value.args) >= 3 + and isinstance(stmt.value.args[2], ast.Dict) + ): + for key, value in zip(stmt.value.args[2].keys, stmt.value.args[2].values): + if not (isinstance(key, ast.Constant) and key.value == "__call__"): + continue + if _lambda_returns_input_float(value): + dynamic_callable_classes.update(target_names) + continue + if ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id in dynamic_callable_classes + ): + dynamic_callable_instances.update(target_names) + + for stmt in tree.body: + if not isinstance(stmt, ast.ClassDef): + continue + for child in stmt.body: + if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if child.name != "__init_subclass__": + continue + params = {arg.arg for arg in child.args.args} + for assign in ast.walk(child): + if not isinstance(assign, ast.Assign): + continue + if not _lambda_returns_input_float(assign.value): + continue + for target in assign.targets: + if not isinstance(target, ast.Attribute): + continue + if not is_entrypoint_name(target.attr): + continue + if isinstance(target.value, ast.Name) and target.value.id in params: + subclass_hook_bases.add(stmt.name) + + for stmt in tree.body: + if not isinstance(stmt, ast.ClassDef): + continue + if any( + isinstance(base, ast.Name) and base.id in subclass_hook_bases + for base in stmt.bases + ): + subclass_entrypoint_classes.add(stmt.name) + + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(t, ast.Name) and is_entrypoint_name(t.id) for t in stmt.targets): + continue + if not ( + isinstance(stmt.value, ast.Attribute) + and is_entrypoint_name(stmt.value.attr) + and isinstance(stmt.value.value, ast.Name) + and stmt.value.value.id in subclass_entrypoint_classes + ): + continue + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} is installed by __init_subclass__ as input.float()", + }] + + for stmt in tree.body: + if not isinstance(stmt, ast.ClassDef): + continue + for child in stmt.body: + if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if child.name != "__get__" or len(child.args.args) < 2: + continue + owner_name = child.args.args[1].arg + returns = [s for s in child.body if isinstance(s, ast.Return)] + if len(returns) != 1: + continue + storage_attr = _input_attr_float_storage(returns[0].value, {owner_name}) + if storage_attr is not None: + descriptor_classes[stmt.name] = storage_attr + + for stmt in tree.body: + if not isinstance(stmt, ast.ClassDef): + continue + + descriptor_attrs: dict[str, str] = {} + for child in stmt.body: + if not isinstance(child, ast.Assign): + continue + if not ( + isinstance(child.value, ast.Call) + and isinstance(child.value.func, ast.Name) + and child.value.func.id in descriptor_classes + ): + continue + for target in child.targets: + if isinstance(target, ast.Name): + descriptor_attrs[target.id] = descriptor_classes[child.value.func.id] + + if not descriptor_attrs: + continue + + init_stores: set[str] = set() + for child in stmt.body: + if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if child.name != "__init__" or not child.args.args: + continue + self_name = child.args.args[0].arg + value_params = {arg.arg for arg in child.args.args[1:]} + for assign in ast.walk(child): + if not isinstance(assign, ast.Assign): + continue + if not (isinstance(assign.value, ast.Name) and assign.value.id in value_params): + continue + for target in assign.targets: + if ( + isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and target.value.id == self_name + ): + init_stores.add(target.attr) + + for public_attr, storage_attr in descriptor_attrs.items(): + if storage_attr in init_stores: + wrapper_descriptor_attrs[(stmt.name, public_attr)] = storage_attr + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + params = {arg.arg for arg in node.args.args} + for stmt in node.body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if not ( + isinstance(stmt, ast.Return) + and isinstance(stmt.value, ast.Attribute) + and isinstance(stmt.value.value, ast.Call) + and isinstance(stmt.value.value.func, ast.Name) + and (stmt.value.value.func.id, stmt.value.attr) in wrapper_descriptor_attrs + and len(stmt.value.value.args) == 1 + and isinstance(stmt.value.value.args[0], ast.Name) + and stmt.value.value.args[0].id in params + and not stmt.value.value.keywords + ): + break + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns a descriptor-backed input.float() value", + }] + + for stmt in tree.body: + if isinstance(stmt, ast.ClassDef) and is_entrypoint_name(stmt.name): + callable_entrypoint_classes.add(stmt.name) + continue + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(t, ast.Name) and is_entrypoint_name(t.id) for t in stmt.targets): + continue + if isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name): + callable_entrypoint_classes.add(stmt.value.func.id) + + for node in ast.walk(tree): + if not isinstance(node, ast.ClassDef) or node.name not in callable_entrypoint_classes: + continue + for stmt in node.body: + if not isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if stmt.name != "__call__": + continue + params = {arg.arg for arg in stmt.args.args} + if _input_float_return_from_body(stmt.body, params): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"callable {entrypoint_name} returns an input tensor cast to float without compute", + }] + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + params = {arg.arg for arg in node.args.args} + + for stmt in node.body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if not ( + isinstance(stmt, ast.Return) + and isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id in dynamic_callable_instances + and len(stmt.value.args) == 1 + and isinstance(stmt.value.args[0], ast.Name) + and stmt.value.args[0].id in params + and not stmt.value.keywords + ): + break + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} delegates to a dynamic callable that returns input.float()", + }] + + return [] + + def detect_torch_compile_cache(code: str | SubmissionFacts) -> list[dict]: """Pattern 11: torch.compile for pipeline graph caching.""" facts = ensure_submission_facts(code) @@ -3511,6 +3858,10 @@ class RulePolicy: "TRIVIAL_PROBE", "low_signal", "telemetry", TELEMETRY_ONLY, (), (), "downgrade", ), + "INPUT_PASSTHROUGH_OUTPUT": RulePolicy( + "INPUT_PASSTHROUGH_OUTPUT", "fake_output", "hard", AUTO_FILTER, (), + (), "keep", + ), "TORCH_COMPILE_CACHE": RulePolicy( "TORCH_COMPILE_CACHE", "performance_heuristic", "telemetry", TELEMETRY_ONLY, (), (), "downgrade", @@ -3767,6 +4118,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: detect_cuda_graph_replay, detect_silent_fallback, detect_trivial_probe, + detect_input_passthrough_output, detect_torch_compile_cache, detect_cuda_graph_python, detect_hardcoded_shapes, @@ -3805,6 +4157,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: ("cuda_graph_replay", detect_cuda_graph_replay), ("silent_fallback", detect_silent_fallback), ("trivial_probe", detect_trivial_probe), + ("input_passthrough_output", detect_input_passthrough_output), ("torch_compile_cache", detect_torch_compile_cache), ("cuda_graph_python", detect_cuda_graph_python), ("hardcoded_shapes", detect_hardcoded_shapes),