|
| 1 | +(* SPDX-License-Identifier: PMPL-1.0-or-later *) |
| 2 | +(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *) |
| 3 | + |
| 4 | +(** Shared infrastructure for kernel-sublanguage backends. |
| 5 | +
|
| 6 | + The Tier-C backends — WGSL, SPIR-V, CUDA, Metal, OpenCL, Faust, ONNX, |
| 7 | + MLIR — all accept a *deliberately restricted* subset of AffineScript. |
| 8 | + Each had its own copy of: |
| 9 | +
|
| 10 | + - an [<Backend>_unsupported] exception + [unsupported] helper |
| 11 | + - a [pick_kernel] / [pick_entry] function that searches for a canonical |
| 12 | + entry name and falls back to the first [fn] |
| 13 | + - an [array_element] helper that strips ownership and matches |
| 14 | + [Array[T]] for a scalar [T] |
| 15 | + - a math-builtin allowlist for trig / pow / sqrt / etc. |
| 16 | + - an [is_unit_ty] check that accepts both [TyCon "Unit"] and [TyTuple []] |
| 17 | +
|
| 18 | + Per the [docs/guides/frontier-programming-practices/AI.a2ml] backends |
| 19 | + section: "If two new tier-C backends ship with overlapping |
| 20 | + restrictions, factor the validator before adding the third." We're |
| 21 | + well past three. This module is the factoring. |
| 22 | +
|
| 23 | + Per-target lowering of *concrete* type names (i32 vs i64 vs int vs |
| 24 | + long, f32 vs f64 vs float vs double) intentionally stays in each |
| 25 | + backend — that's not duplication, that's correctly different. *) |
| 26 | + |
| 27 | +open Ast |
| 28 | + |
| 29 | +(* ============================================================================ |
| 30 | + Common exception |
| 31 | + ============================================================================ *) |
| 32 | + |
| 33 | +(** Raised by any kernel-sublanguage backend when the source falls outside |
| 34 | + its accepted subset. The string is the user-facing reason; backends |
| 35 | + typically prefix it with their name (e.g. "WGSL backend: ..."). *) |
| 36 | +exception Unsupported of string |
| 37 | + |
| 38 | +let unsupported (msg : string) : 'a = raise (Unsupported msg) |
| 39 | + |
| 40 | +(* ============================================================================ |
| 41 | + Entry-function selection |
| 42 | +
|
| 43 | + Most kernel backends look for a function named [kernel], [process], |
| 44 | + [main], or [graph] and fall back to the first user-defined function in |
| 45 | + the program. The default name list mirrors what the existing backends |
| 46 | + used; callers can override. |
| 47 | + ============================================================================ *) |
| 48 | + |
| 49 | +let default_entry_names = ["kernel"; "process"; "main"; "graph"] |
| 50 | + |
| 51 | +let pick_entry ?(names = default_entry_names) (program : program) : fn_decl = |
| 52 | + let fns = List.filter_map (function TopFn fd -> Some fd | _ -> None) |
| 53 | + program.prog_decls in |
| 54 | + let rec try_names = function |
| 55 | + | n :: rest -> |
| 56 | + (match List.find_opt (fun fd -> fd.fd_name.name = n) fns with |
| 57 | + | Some fd -> fd |
| 58 | + | None -> try_names rest) |
| 59 | + | [] -> |
| 60 | + (match fns with |
| 61 | + | fd :: _ -> fd |
| 62 | + | [] -> unsupported "no function found to lower as kernel") |
| 63 | + in |
| 64 | + try_names names |
| 65 | + |
| 66 | +(* ============================================================================ |
| 67 | + Type-shape predicates and helpers |
| 68 | + ============================================================================ *) |
| 69 | + |
| 70 | +let rec strip_ownership (te : type_expr) : type_expr = |
| 71 | + match te with |
| 72 | + | TyOwn t | TyRef t | TyMut t -> strip_ownership t |
| 73 | + | t -> t |
| 74 | + |
| 75 | +(** Accept [Unit] either as the named type or as the empty tuple, since |
| 76 | + [-> ()] parses as [TyTuple []] but [-> Unit] parses as [TyCon "Unit"]. *) |
| 77 | +let is_unit_ty (te : type_expr) : bool = |
| 78 | + match strip_ownership te with |
| 79 | + | TyCon id when id.name = "Unit" -> true |
| 80 | + | TyTuple [] -> true |
| 81 | + | _ -> false |
| 82 | + |
| 83 | +(** Standard primitive scalar names — the intersection that every kernel |
| 84 | + sublanguage we currently target supports. Backends restricting further |
| 85 | + (e.g. Faust's Float-only kernels) should layer their own check on top. *) |
| 86 | +let is_scalar_type_name (n : string) : bool = |
| 87 | + n = "Int" || n = "Float" || n = "Bool" |
| 88 | + |
| 89 | +let is_scalar_type (te : type_expr) : bool = |
| 90 | + match strip_ownership te with |
| 91 | + | TyCon id -> is_scalar_type_name id.name |
| 92 | + | _ -> false |
| 93 | + |
| 94 | +(** Decompose [Array[T]] for a scalar [T], returning the inner type-con |
| 95 | + name (e.g. ["Int"], ["Float"]). Returns [None] for non-array or |
| 96 | + non-scalar-element shapes. *) |
| 97 | +let array_element (te : type_expr) : string option = |
| 98 | + match strip_ownership te with |
| 99 | + | TyApp (id, [TyArg inner]) when id.name = "Array" -> |
| 100 | + (match strip_ownership inner with |
| 101 | + | TyCon id when is_scalar_type_name id.name -> Some id.name |
| 102 | + | _ -> None) |
| 103 | + | _ -> None |
| 104 | + |
| 105 | +(** Same as [array_element] but raises [Unsupported] with a useful message |
| 106 | + on shapes that don't decompose. The expected shape string is included |
| 107 | + in the error so users see e.g. "expected Array[Float]" rather than a |
| 108 | + generic "type not allowed." *) |
| 109 | +let require_array_element (expected : string) (te : type_expr) : string = |
| 110 | + match array_element te with |
| 111 | + | Some name -> name |
| 112 | + | None -> |
| 113 | + unsupported (Printf.sprintf "expected %s for kernel buffer" expected) |
| 114 | + |
| 115 | +(* ============================================================================ |
| 116 | + Math builtins shared across kernel backends |
| 117 | +
|
| 118 | + Names are taken from the common subset (WGSL spec, GLSL, Metal stdlib, |
| 119 | + OpenCL, CUDA math.h, Faust). [int]/[float] are coercions; the rest are |
| 120 | + real math intrinsics. Backends may add target-specific names (e.g. |
| 121 | + Metal's [metal::float4]) but should start from this list. |
| 122 | + ============================================================================ *) |
| 123 | + |
| 124 | +let math_builtins : string list = [ |
| 125 | + "sin"; "cos"; "tan"; "asin"; "acos"; "atan"; "atan2"; |
| 126 | + "sinh"; "cosh"; "tanh"; |
| 127 | + "exp"; "log"; "log2"; "log10"; "pow"; "sqrt"; "rsqrt"; |
| 128 | + "floor"; "ceil"; "round"; "trunc"; "fract"; "fmod"; |
| 129 | + "abs"; "min"; "max"; "clamp"; "sign"; |
| 130 | + "mix"; "step"; "smoothstep"; |
| 131 | + "int"; "float"; "i32"; "u32"; "f32"; |
| 132 | +] |
| 133 | + |
| 134 | +let is_math_builtin (name : string) : bool = List.mem name math_builtins |
| 135 | + |
| 136 | +(* ============================================================================ |
| 137 | + Validation helpers |
| 138 | +
|
| 139 | + These build on [is_scalar_type], [array_element], and [is_unit_ty] to |
| 140 | + produce friendly error messages. They are convenience wrappers — every |
| 141 | + backend can implement its own validation if its rules differ. |
| 142 | + ============================================================================ *) |
| 143 | + |
| 144 | +(** Require the function's return type satisfies [pred]; raise on violation. *) |
| 145 | +let validate_return (pred : type_expr -> bool) (expected : string) (fd : fn_decl) : unit = |
| 146 | + match fd.fd_ret_ty with |
| 147 | + | None -> () (* no annotation — caller decides whether to accept *) |
| 148 | + | Some t when pred t -> () |
| 149 | + | _ -> unsupported |
| 150 | + (Printf.sprintf "kernel function must return %s" expected) |
| 151 | + |
| 152 | +(** Require every parameter satisfies [pred]; raise on the first that |
| 153 | + doesn't, naming it. *) |
| 154 | +let validate_params (pred : type_expr -> bool) (expected : string) |
| 155 | + (fd : fn_decl) : unit = |
| 156 | + List.iter (fun (p : param) -> |
| 157 | + if not (pred p.p_ty) then |
| 158 | + unsupported |
| 159 | + (Printf.sprintf "parameter %s must be %s" p.p_name.name expected) |
| 160 | + ) fd.fd_params |
| 161 | + |
| 162 | +(** Common shape: first param is [Int] (the global invocation index), |
| 163 | + remaining params are [Array[T]] buffers, return type is [Unit]. Used |
| 164 | + by WGSL / SPIR-V / CUDA / Metal / OpenCL. *) |
| 165 | +let validate_compute_kernel_shape (fd : fn_decl) : unit = |
| 166 | + validate_return (fun t -> is_unit_ty t) "Unit or ()" fd; |
| 167 | + match fd.fd_params with |
| 168 | + | [] -> unsupported "kernel must take at least an Int index parameter" |
| 169 | + | first :: rest -> |
| 170 | + (match strip_ownership first.p_ty with |
| 171 | + | TyCon id when id.name = "Int" -> () |
| 172 | + | _ -> unsupported "first kernel parameter must be Int (the global index)"); |
| 173 | + List.iter (fun (p : param) -> |
| 174 | + ignore (require_array_element "Array[Int|Float]" p.p_ty) |
| 175 | + ) rest |
0 commit comments