diff --git a/kernelguard.py b/kernelguard.py
index f086847..29ac17c 100644
--- a/kernelguard.py
+++ b/kernelguard.py
@@ -553,6 +553,76 @@ def _looks_stateful_name(name: str) -> bool:
     return any(token in lowered for token in ("last", "prev", "cache", "saved", "memo"))
 
 
+def _is_input_float_call(expr: ast.AST | None, params: set[str]) -> bool:
+    return (  # True iff expr is exactly `<param>.float()` with no arguments
+        isinstance(expr, ast.Call)
+        and isinstance(expr.func, ast.Attribute)
+        and expr.func.attr == "float"
+        and isinstance(expr.func.value, ast.Name)
+        and expr.func.value.id in params
+        and not expr.args
+        and not expr.keywords
+    )
+
+
+def _lambda_returns_input_float(expr: ast.AST | None) -> bool:
+    if not isinstance(expr, ast.Lambda):
+        return False
+    params = {arg.arg for arg in expr.args.posonlyargs + expr.args.args + expr.args.kwonlyargs}
+    return _is_input_float_call(expr.body, params)
+
+
+def _is_noop_context_call(expr: ast.AST | None) -> bool:
+    if not isinstance(expr, ast.Call):
+        return False
+    func = expr.func
+    if isinstance(func, ast.Name):
+        return func.id == "nullcontext"
+    return isinstance(func, ast.Attribute) and func.attr == "nullcontext"
+
+
+def _input_float_return_from_body(body: list[ast.stmt], params: set[str]) -> bool:
+    aliases: set[str] = set()  # names bound directly to `<param>.float()`
+    saw_return = False
+
+    for stmt in body:
+        if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)):
+            continue
+        if (
+            isinstance(stmt, ast.Expr)
+            and isinstance(stmt.value, ast.Constant)
+            and isinstance(stmt.value.value, str)
+        ):
+            continue  # bare string expression (docstring) carries no compute
+        if isinstance(stmt, ast.With) and all(
+            _is_noop_context_call(item.context_expr)
+            for item in stmt.items
+        ):
+            if _input_float_return_from_body(stmt.body, params):
+                saw_return = True
+                continue
+            return False
+        if (
+            isinstance(stmt, ast.Assign)
+            and len(stmt.targets) == 1
+            and isinstance(stmt.targets[0], ast.Name)
+            and _is_input_float_call(stmt.value, params)
+        ):
+            aliases.add(stmt.targets[0].id)
+            continue
+        if isinstance(stmt, ast.Return):
+            if _is_input_float_call(stmt.value, params):
+                saw_return = True
+                continue
+            if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases:
+                saw_return = True
+                continue
+            return False
+        return False  # any other statement: assume real compute happens
+
+    return saw_return
+
+
 # ---------------------------------------------------------------------------
 # Detectors
 # ---------------------------------------------------------------------------
@@ -1163,6 +1233,120 @@ def detect_trivial_probe(code: str | SubmissionFacts) -> list[dict]:
     return matches
 
 
+def detect_input_passthrough_output(code: str | SubmissionFacts) -> list[dict]:
+    """Pattern: entrypoint returns the input tensor cast to float as fake output."""
+    facts = ensure_submission_facts(code)
+    tree = facts.ast_tree
+    if tree is None:
+        return []
+    entrypoint_name = entrypoint_label(facts.entrypoint_name)
+    callable_entrypoint_classes: set[str] = set()
+    dynamic_callable_classes: set[str] = set()
+    dynamic_callable_instances: set[str] = set()
+
+    for node in ast.walk(tree):  # stage 1: plain function entrypoints
+        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            continue
+        if not is_entrypoint_name(node.name):
+            continue
+
+        params = {arg.arg for arg in node.args.posonlyargs + node.args.args + node.args.kwonlyargs}
+        if not params:
+            continue
+        if _input_float_return_from_body(node.body, params):
+            return [{
+                "pattern": "INPUT_PASSTHROUGH_OUTPUT",
+                "severity": "critical",
+                "evidence": f"{entrypoint_name} returns an input tensor cast to float without compute",
+            }]
+
+    for stmt in tree.body:  # stage 2: type(...)-built classes with a passthrough __call__
+        if not isinstance(stmt, ast.Assign):
+            continue
+        target_names = {t.id for t in stmt.targets if isinstance(t, ast.Name)}
+        if (
+            isinstance(stmt.value, ast.Call)
+            and isinstance(stmt.value.func, ast.Name)
+            and stmt.value.func.id == "type"
+            and len(stmt.value.args) >= 3
+            and isinstance(stmt.value.args[2], ast.Dict)
+        ):
+            for key, value in zip(stmt.value.args[2].keys, stmt.value.args[2].values):
+                if not (isinstance(key, ast.Constant) and key.value == "__call__"):
+                    continue
+                if _lambda_returns_input_float(value):
+                    dynamic_callable_classes.update(target_names)
+            continue
+        if (
+            isinstance(stmt.value, ast.Call)
+            and isinstance(stmt.value.func, ast.Name)
+            and stmt.value.func.id in dynamic_callable_classes
+        ):
+            dynamic_callable_instances.update(target_names)
+
+    for stmt in tree.body:  # stage 3: classes (or instance aliases) acting as the entrypoint
+        if isinstance(stmt, ast.ClassDef) and is_entrypoint_name(stmt.name):
+            callable_entrypoint_classes.add(stmt.name)
+            continue
+        if not isinstance(stmt, ast.Assign):
+            continue
+        if not any(isinstance(t, ast.Name) and is_entrypoint_name(t.id) for t in stmt.targets):
+            continue
+        if isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name):
+            callable_entrypoint_classes.add(stmt.value.func.id)
+
+    for node in ast.walk(tree):
+        if not isinstance(node, ast.ClassDef) or node.name not in callable_entrypoint_classes:
+            continue
+        for stmt in node.body:
+            if not isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
+                continue
+            if stmt.name != "__call__":
+                continue
+            params = {arg.arg for arg in stmt.args.posonlyargs + stmt.args.args + stmt.args.kwonlyargs}
+            if _input_float_return_from_body(stmt.body, params):
+                return [{
+                    "pattern": "INPUT_PASSTHROUGH_OUTPUT",
+                    "severity": "critical",
+                    "evidence": f"callable {entrypoint_name} returns an input tensor cast to float without compute",
+                }]
+
+    for node in ast.walk(tree):  # stage 4: entrypoint delegating to a dynamic passthrough callable
+        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            continue
+        if not is_entrypoint_name(node.name):
+            continue
+        params = {arg.arg for arg in node.args.posonlyargs + node.args.args + node.args.kwonlyargs}
+
+        for stmt in node.body:
+            if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)):
+                continue
+            if (
+                isinstance(stmt, ast.Expr)
+                and isinstance(stmt.value, ast.Constant)
+                and isinstance(stmt.value.value, str)
+            ):
+                continue
+            if not (
+                isinstance(stmt, ast.Return)
+                and isinstance(stmt.value, ast.Call)
+                and isinstance(stmt.value.func, ast.Name)
+                and stmt.value.func.id in dynamic_callable_instances
+                and len(stmt.value.args) == 1
+                and isinstance(stmt.value.args[0], ast.Name)
+                and stmt.value.args[0].id in params
+                and not stmt.value.keywords
+            ):
+                break
+            return [{
+                "pattern": "INPUT_PASSTHROUGH_OUTPUT",
+                "severity": "critical",
+                "evidence": f"{entrypoint_name} delegates to a dynamic callable that returns input.float()",
+            }]
+
+    return []
+
+
 def detect_torch_compile_cache(code: str | SubmissionFacts) -> list[dict]:
     """Pattern 11: torch.compile for pipeline graph caching."""
     facts = ensure_submission_facts(code)
@@ -3511,6 +3695,10 @@ class RulePolicy:
         "TRIVIAL_PROBE", "low_signal", "telemetry", TELEMETRY_ONLY, (),
         (), "downgrade",
     ),
+    "INPUT_PASSTHROUGH_OUTPUT": RulePolicy(
+        "INPUT_PASSTHROUGH_OUTPUT", "fake_output", "hard", AUTO_FILTER, (),
+        (), "keep",
+    ),
     "TORCH_COMPILE_CACHE": RulePolicy(
         "TORCH_COMPILE_CACHE", "performance_heuristic", "telemetry", TELEMETRY_ONLY,
         (), (), "downgrade",
@@ -3767,6 +3955,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
     detect_cuda_graph_replay,
     detect_silent_fallback,
     detect_trivial_probe,
+    detect_input_passthrough_output,
    detect_torch_compile_cache,
     detect_cuda_graph_python,
     detect_hardcoded_shapes,
@@ -3805,6 +3994,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
     ("cuda_graph_replay", detect_cuda_graph_replay),
     ("silent_fallback", detect_silent_fallback),
     ("trivial_probe", detect_trivial_probe),
+    ("input_passthrough_output", detect_input_passthrough_output),
     ("torch_compile_cache", detect_torch_compile_cache),
     ("cuda_graph_python", detect_cuda_graph_python),
     ("hardcoded_shapes", detect_hardcoded_shapes),