diff --git a/kernelguard.py b/kernelguard.py index f086847..bc73328 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -553,6 +553,232 @@ def _looks_stateful_name(name: str) -> bool: return any(token in lowered for token in ("last", "prev", "cache", "saved", "memo")) +def _is_input_float_call(expr: ast.AST | None, params: set[str]) -> bool: + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr == "float" + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id in params + and not expr.args + and not expr.keywords + ) + + +def _is_input_reduction_call(expr: ast.AST | None, params: set[str], torch_aliases: set[str]) -> bool: + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr in {"any", "all"} + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id in torch_aliases + and len(expr.args) == 1 + and isinstance(expr.args[0], ast.Name) + and expr.args[0].id in params + and not expr.keywords + ) + + +def _is_param_t_call(expr: ast.AST | None, param: str) -> bool: + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr == "t" + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id == param + and not expr.args + and not expr.keywords + ) + + +def _is_self_matmul_expr( + expr: ast.AST | None, + params: set[str], + torch_aliases: Optional[set[str]] = None, + mm_aliases: Optional[set[str]] = None, +) -> bool: + torch_aliases = torch_aliases or set() + mm_aliases = mm_aliases or set() + if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.MatMult): + return any( + isinstance(expr.left, ast.Name) + and expr.left.id == param + and _is_param_t_call(expr.right, param) + for param in params + ) + if not isinstance(expr, ast.Call): + return False + if len(expr.args) != 2 or expr.keywords: + return False + for param in params: + if not ( + isinstance(expr.args[0], ast.Name) + and expr.args[0].id == param + and _is_param_t_call(expr.args[1], param) + ): + continue + if isinstance(expr.func, ast.Name) and expr.func.id in mm_aliases: + return True + if ( + isinstance(expr.func, ast.Attribute) + and expr.func.attr in {"mm", "matmul"} + and isinstance(expr.func.value, ast.Name) + and expr.func.value.id in torch_aliases + ): + return True + return False + + +def _lambda_returns_input_float(expr: ast.AST | None) -> bool: + if not isinstance(expr, ast.Lambda): + return False + params = {arg.arg for arg in expr.args.args} + return _is_input_float_call(expr.body, params) + + +def _input_attr_float_storage(expr: ast.AST | None, owners: set[str]) -> Optional[str]: + if not ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Attribute) + and expr.func.attr == "float" + and not expr.args + and not expr.keywords + and isinstance(expr.func.value, ast.Attribute) + and isinstance(expr.func.value.value, ast.Name) + and expr.func.value.value.id in owners + ): + return None + return expr.func.value.attr + + +def _is_noop_context_call(expr: ast.AST | None) -> bool: + if not isinstance(expr, ast.Call): + return False + func = expr.func + if isinstance(func, ast.Name): + return func.id == "nullcontext" + return isinstance(func, ast.Attribute) and func.attr == "nullcontext" + + +def _input_float_return_from_body(body: list[ast.stmt], params: set[str]) -> bool: + aliases: set[str] = set() + saw_return = False + + for stmt in body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if isinstance(stmt, ast.With) and all( + _is_noop_context_call(item.context_expr) + for item in stmt.items + ): + if _input_float_return_from_body(stmt.body, params): + saw_return = True + continue + return False + if ( + isinstance(stmt, ast.Assign) + and len(stmt.targets) == 1 + and isinstance(stmt.targets[0], ast.Name) + and _is_input_float_call(stmt.value, params) + ): + aliases.add(stmt.targets[0].id) + continue + if isinstance(stmt, ast.Return): + if _is_input_float_call(stmt.value, params): + saw_return = True + continue + if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases: + saw_return = True + continue + return False + return False + + return saw_return + + +def _input_reduction_return_from_body( + body: list[ast.stmt], + params: set[str], + torch_aliases: set[str], +) -> bool: + aliases: set[str] = set() + saw_return = False + + for stmt in body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if ( + isinstance(stmt, ast.Assign) + and len(stmt.targets) == 1 + and isinstance(stmt.targets[0], ast.Name) + and _is_input_reduction_call(stmt.value, params, torch_aliases) + ): + aliases.add(stmt.targets[0].id) + continue + if isinstance(stmt, ast.Return): + if _is_input_reduction_call(stmt.value, params, torch_aliases): + saw_return = True + continue + if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases: + saw_return = True + continue + return False + return False + + return saw_return + + +def _self_matmul_return_from_body( + body: list[ast.stmt], + params: set[str], + torch_aliases: Optional[set[str]] = None, + mm_aliases: Optional[set[str]] = None, +) -> bool: + aliases: set[str] = set() + saw_return = False + + for stmt in body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if ( + isinstance(stmt, ast.Assign) + and len(stmt.targets) == 1 + and isinstance(stmt.targets[0], ast.Name) + and _is_self_matmul_expr(stmt.value, params, torch_aliases, mm_aliases) + ): + aliases.add(stmt.targets[0].id) + continue + if isinstance(stmt, ast.Return): + if _is_self_matmul_expr(stmt.value, params, torch_aliases, mm_aliases): + saw_return = True + continue + if isinstance(stmt.value, ast.Name) and stmt.value.id in aliases: + saw_return = True + continue + return False + return False + + return saw_return + + # --------------------------------------------------------------------------- # Detectors # --------------------------------------------------------------------------- @@ -1163,6 +1389,381 @@ def detect_trivial_probe(code: str | SubmissionFacts) -> list[dict]: return matches +def detect_input_passthrough_output(code: str | SubmissionFacts) -> list[dict]: + """Pattern: entrypoint returns the input tensor cast to float as fake output.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + callable_entrypoint_classes: set[str] = set() + dynamic_callable_classes: set[str] = set() + dynamic_callable_instances: set[str] = set() + subclass_hook_bases: set[str] = set() + subclass_entrypoint_classes: set[str] = set() + descriptor_classes: dict[str, str] = {} + wrapper_descriptor_attrs: dict[tuple[str, str], str] = {} + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + + params = {arg.arg for arg in node.args.args} + if not params: + continue + if _input_float_return_from_body(node.body, params): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns an input tensor cast to float without compute", + }] + + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + target_names = {t.id for t in stmt.targets if isinstance(t, ast.Name)} + if ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id == "type" + and len(stmt.value.args) >= 3 + and isinstance(stmt.value.args[2], ast.Dict) + ): + for key, value in zip(stmt.value.args[2].keys, stmt.value.args[2].values): + if not (isinstance(key, ast.Constant) and key.value == "__call__"): + continue + if _lambda_returns_input_float(value): + dynamic_callable_classes.update(target_names) + continue + if ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id in dynamic_callable_classes + ): + dynamic_callable_instances.update(target_names) + + for stmt in tree.body: + if not isinstance(stmt, ast.ClassDef): + continue + for child in stmt.body: + if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if child.name != "__init_subclass__": + continue + params = {arg.arg for arg in child.args.args} + for assign in ast.walk(child): + if not isinstance(assign, ast.Assign): + continue + if not _lambda_returns_input_float(assign.value): + continue + for target in assign.targets: + if not isinstance(target, ast.Attribute): + continue + if not is_entrypoint_name(target.attr): + continue + if isinstance(target.value, ast.Name) and target.value.id in params: + subclass_hook_bases.add(stmt.name) + + for stmt in tree.body: + if not isinstance(stmt, ast.ClassDef): + continue + if any( + isinstance(base, ast.Name) and base.id in subclass_hook_bases + for base in stmt.bases + ): + subclass_entrypoint_classes.add(stmt.name) + + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(t, ast.Name) and is_entrypoint_name(t.id) for t in stmt.targets): + continue + if not ( + isinstance(stmt.value, ast.Attribute) + and is_entrypoint_name(stmt.value.attr) + and isinstance(stmt.value.value, ast.Name) + and stmt.value.value.id in subclass_entrypoint_classes + ): + continue + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} is installed by __init_subclass__ as input.float()", + }] + + for stmt in tree.body: + if not isinstance(stmt, ast.ClassDef): + continue + for child in stmt.body: + if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if child.name != "__get__" or len(child.args.args) < 2: + continue + owner_name = child.args.args[1].arg + returns = [s for s in child.body if isinstance(s, ast.Return)] + if len(returns) != 1: + continue + storage_attr = _input_attr_float_storage(returns[0].value, {owner_name}) + if storage_attr is not None: + descriptor_classes[stmt.name] = storage_attr + + for stmt in tree.body: + if not isinstance(stmt, ast.ClassDef): + continue + + descriptor_attrs: dict[str, str] = {} + for child in stmt.body: + if not isinstance(child, ast.Assign): + continue + if not ( + isinstance(child.value, ast.Call) + and isinstance(child.value.func, ast.Name) + and child.value.func.id in descriptor_classes + ): + continue + for target in child.targets: + if isinstance(target, ast.Name): + descriptor_attrs[target.id] = descriptor_classes[child.value.func.id] + + if not descriptor_attrs: + continue + + init_stores: set[str] = set() + for child in stmt.body: + if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if child.name != "__init__" or not child.args.args: + continue + self_name = child.args.args[0].arg + value_params = {arg.arg for arg in child.args.args[1:]} + for assign in ast.walk(child): + if not isinstance(assign, ast.Assign): + continue + if not (isinstance(assign.value, ast.Name) and assign.value.id in value_params): + continue + for target in assign.targets: + if ( + isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and target.value.id == self_name + ): + init_stores.add(target.attr) + + for public_attr, storage_attr in descriptor_attrs.items(): + if storage_attr in init_stores: + wrapper_descriptor_attrs[(stmt.name, public_attr)] = storage_attr + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + params = {arg.arg for arg in node.args.args} + for stmt in node.body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if not ( + isinstance(stmt, ast.Return) + and isinstance(stmt.value, ast.Attribute) + and isinstance(stmt.value.value, ast.Call) + and isinstance(stmt.value.value.func, ast.Name) + and (stmt.value.value.func.id, stmt.value.attr) in wrapper_descriptor_attrs + and len(stmt.value.value.args) == 1 + and isinstance(stmt.value.value.args[0], ast.Name) + and stmt.value.value.args[0].id in params + and not stmt.value.value.keywords + ): + break + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns a descriptor-backed input.float() value", + }] + + for stmt in tree.body: + if isinstance(stmt, ast.ClassDef) and is_entrypoint_name(stmt.name): + callable_entrypoint_classes.add(stmt.name) + continue + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(t, ast.Name) and is_entrypoint_name(t.id) for t in stmt.targets): + continue + if isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name): + callable_entrypoint_classes.add(stmt.value.func.id) + + for node in ast.walk(tree): + if not isinstance(node, ast.ClassDef) or node.name not in callable_entrypoint_classes: + continue + for stmt in node.body: + if not isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if stmt.name != "__call__": + continue + params = {arg.arg for arg in stmt.args.args} + if _input_float_return_from_body(stmt.body, params): + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"callable {entrypoint_name} returns an input tensor cast to float without compute", + }] + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + params = {arg.arg for arg in node.args.args} + + for stmt in node.body: + if isinstance(stmt, (ast.Import, ast.ImportFrom, ast.Pass)): + continue + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + continue + if not ( + isinstance(stmt, ast.Return) + and isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id in dynamic_callable_instances + and len(stmt.value.args) == 1 + and isinstance(stmt.value.args[0], ast.Name) + and stmt.value.args[0].id in params + and not stmt.value.keywords + ): + break + return [{ + "pattern": "INPUT_PASSTHROUGH_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} delegates to a dynamic callable that returns input.float()", + }] + + return [] + + +def detect_input_reduction_output(code: str | SubmissionFacts) -> list[dict]: + """Pattern: entrypoint returns torch.any/all(input) as fake output.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + torch_aliases = {"torch"} + + for node in facts._imports: + for alias in node.names: + if alias.name == "torch": + torch_aliases.add(alias.asname or alias.name) + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + params = {arg.arg for arg in node.args.args} + if params and _input_reduction_return_from_body(node.body, params, torch_aliases): + return [{ + "pattern": "INPUT_REDUCTION_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns torch.any/all(input) without compute", + }] + + return [] + + +def detect_partial_self_matmul_output(code: str | SubmissionFacts) -> list[dict]: + """Pattern: custom_kernel partial wraps helper returning data @ data.t().""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + helper_names: set[str] = set() + partial_aliases = {"partial"} + functools_aliases = {"functools"} + torch_aliases = {"torch"} + mm_aliases: set[str] = set() + + for node in facts._imports: + for alias in node.names: + if alias.name == "functools": + functools_aliases.add(alias.asname or alias.name) + if alias.name == "torch": + torch_aliases.add(alias.asname or alias.name) + for node in facts._import_froms: + if node.module == "functools": + for alias in node.names: + if alias.name == "partial": + partial_aliases.add(alias.asname or alias.name) + if node.module == "torch": + for alias in node.names: + if alias.name in {"mm", "matmul"}: + mm_aliases.add(alias.asname or alias.name) + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + params = {arg.arg for arg in node.args.args} + if params and _self_matmul_return_from_body(node.body, params, torch_aliases, mm_aliases): + helper_names.add(node.name) + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + params = {arg.arg for arg in node.args.args} + if params and _self_matmul_return_from_body(node.body, params, torch_aliases, mm_aliases): + return [{ + "pattern": "SELF_MATMUL_OUTPUT", + "severity": "critical", + "evidence": "custom_kernel returns data @ data.t() or equivalent self-matmul", + }] + + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not any(isinstance(t, ast.Name) and is_entrypoint_name(t.id) for t in stmt.targets): + continue + value = stmt.value + if not ( + isinstance(value, ast.Call) + and len(value.args) == 1 + and isinstance(value.args[0], ast.Name) + and value.args[0].id in helper_names + and not value.keywords + ): + continue + if isinstance(value.func, ast.Name) and value.func.id in partial_aliases: + return [{ + "pattern": "SELF_MATMUL_OUTPUT", + "severity": "critical", + "evidence": "custom_kernel partial wraps helper returning data @ data.t()", + }] + if ( + isinstance(value.func, ast.Attribute) + and value.func.attr == "partial" + and isinstance(value.func.value, ast.Name) + and value.func.value.id in functools_aliases + ): + return [{ + "pattern": "SELF_MATMUL_OUTPUT", + "severity": "critical", + "evidence": "custom_kernel functools.partial wraps helper returning data @ data.t()", + }] + + return [] + + def detect_torch_compile_cache(code: str | SubmissionFacts) -> list[dict]: """Pattern 11: torch.compile for pipeline graph caching.""" facts = ensure_submission_facts(code) @@ -3511,6 +4112,18 @@ class RulePolicy: "TRIVIAL_PROBE", "low_signal", "telemetry", TELEMETRY_ONLY, (), (), "downgrade", ), + "INPUT_PASSTHROUGH_OUTPUT": RulePolicy( + "INPUT_PASSTHROUGH_OUTPUT", "fake_output", "hard", AUTO_FILTER, (), + (), "keep", + ), + "INPUT_REDUCTION_OUTPUT": RulePolicy( + "INPUT_REDUCTION_OUTPUT", "fake_output", "hard", AUTO_FILTER, (), + (), "keep", + ), + "SELF_MATMUL_OUTPUT": RulePolicy( + "SELF_MATMUL_OUTPUT", "fake_output", "hard", AUTO_FILTER, (), + (), "keep", + ), "TORCH_COMPILE_CACHE": RulePolicy( "TORCH_COMPILE_CACHE", "performance_heuristic", "telemetry", TELEMETRY_ONLY, (), (), "downgrade", @@ -3767,6 +4380,9 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: detect_cuda_graph_replay, detect_silent_fallback, detect_trivial_probe, + detect_input_passthrough_output, + detect_input_reduction_output, + detect_partial_self_matmul_output, detect_torch_compile_cache, detect_cuda_graph_python, detect_hardcoded_shapes, @@ -3805,6 +4421,9 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: ("cuda_graph_replay", detect_cuda_graph_replay), ("silent_fallback", detect_silent_fallback), ("trivial_probe", detect_trivial_probe), + ("input_passthrough_output", detect_input_passthrough_output), + ("input_reduction_output", detect_input_reduction_output), + ("partial_self_matmul_output", detect_partial_self_matmul_output), ("torch_compile_cache", detect_torch_compile_cache), ("cuda_graph_python", detect_cuda_graph_python), ("hardcoded_shapes", detect_hardcoded_shapes),