diff --git a/kernelguard.py b/kernelguard.py index f086847..50629f0 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -553,6 +553,69 @@ def _looks_stateful_name(name: str) -> bool: return any(token in lowered for token in ("last", "prev", "cache", "saved", "memo")) +def _is_input_float_call(expr: ast.AST | None, params: set[str]) -> bool: + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr == "float" + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id in params + and not expr.args + and not expr.keywords + ) + + +def _is_noop_context_call(expr: ast.AST | None) -> bool: + if not isinstance(expr, ast.Call): + return False + func = expr.func + if isinstance(func, ast.Name): + return func.id == "nullcontext" + return isinstance(func, ast.Attribute) and func.attr == "nullcontext" + + +def _input_float_return_from_body(body: list[ast.stmt], params: set[str]) -> bool: + aliases: set[str] = set() + saw_return = False + + for stmt in body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if isinstance(stmt, ast.With) and all( + _is_noop_context_call(item.context_expr) + for item in stmt.items + ): + if _input_float_return_from_body(stmt.body, params): + saw_return = True + continue + return False + if ( + isinstance(stmt, ast.Assign) + and len(stmt.targets) == 1 + and isinstance(stmt.targets[0], ast.Name) + and _is_input_float_call(stmt.value, params) + ): + aliases.add(stmt.targets[0].id) + continue + if isinstance(stmt, ast.Return): + if _is_input_float_call(stmt.value, params): + saw_return = True + continue + if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases: + saw_return = True + continue + return False + return False + + return saw_return + + # --------------------------------------------------------------------------- # 
def detect_input_passthrough_output(code: str | SubmissionFacts) -> list[dict]:
    """Pattern: entrypoint returns the input tensor cast to float as fake output.

    Scans every function whose name looks like an entrypoint and flags it as
    critical when its body is nothing but ``return <input>.float()`` (possibly
    via an alias or wrapped in a no-op ``nullcontext`` block), per
    ``_input_float_return_from_body``.

    Returns at most one match dict; an empty list when nothing matches or the
    submission has no parseable AST.
    """
    facts = ensure_submission_facts(code)
    tree = facts.ast_tree
    if tree is None:
        return []
    entrypoint_name = entrypoint_label(facts.entrypoint_name)

    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if not is_entrypoint_name(node.name):
            continue

        # Collect every named parameter, including positional-only
        # (`def kernel(x, /)`) and keyword-only ones — the original
        # `args.args`-only scan silently missed those signatures.
        arguments = node.args
        params = {
            arg.arg
            for arg in (
                *arguments.posonlyargs,
                *arguments.args,
                *arguments.kwonlyargs,
            )
        }
        if not params:
            continue
        if _input_float_return_from_body(node.body, params):
            return [{
                "pattern": "INPUT_PASSTHROUGH_OUTPUT",
                "severity": "critical",
                "evidence": f"{entrypoint_name} returns an input tensor cast to float without compute",
            }]

    return []
detect_silent_fallback), ("trivial_probe", detect_trivial_probe), + ("input_passthrough_output", detect_input_passthrough_output), ("torch_compile_cache", detect_torch_compile_cache), ("cuda_graph_python", detect_cuda_graph_python), ("hardcoded_shapes", detect_hardcoded_shapes),