Skip to content

Commit 2d45a26

Browse files
Merge pull request #78 from hyperpolymath/fix/ci-corrective
fix(build): format lib/dune + track kernel_sublang.ml
2 parents 051841c + 06fe90f commit 2d45a26

2 files changed

Lines changed: 249 additions & 1 deletion

File tree

lib/dune

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,80 @@
22
(name affinescript)
33
(public_name affinescript)
44
(modes byte native)
5-
(modules ast borrow c_codegen cafe_face codegen codegen_gc desugar_traits effect error error_collector error_formatter face face_pragma formatter bash_codegen cuda_codegen faust_codegen gleam_codegen interp js_codegen js_face julia_codegen json_output lean_codegen llvm_codegen lua_codegen metal_codegen mlir_codegen nickel_codegen ocaml_codegen onnx_codegen onnx_proto opencl_codegen protobuf rescript_codegen rust_codegen spirv_codegen verilog_codegen why3_codegen lexer linter lsp_server lucid_face module_loader opt parse_driver parse parser parser_errors pseudocode_face python_face quantity resolve span symbol tea_bridge tea_cs_bridge tea_router token trait tw_interface tw_verify typecheck types unify value wasm wasm_encode wasm_gc wasm_gc_encode wasi_runtime wgsl_codegen)
5+
(modules
6+
ast
7+
borrow
8+
c_codegen
9+
cafe_face
10+
codegen
11+
codegen_gc
12+
desugar_traits
13+
effect
14+
error
15+
error_collector
16+
error_formatter
17+
face
18+
face_pragma
19+
formatter
20+
bash_codegen
21+
cuda_codegen
22+
faust_codegen
23+
gleam_codegen
24+
interp
25+
js_codegen
26+
js_face
27+
julia_codegen
28+
json_output
29+
kernel_sublang
30+
lean_codegen
31+
llvm_codegen
32+
lua_codegen
33+
metal_codegen
34+
mlir_codegen
35+
nickel_codegen
36+
ocaml_codegen
37+
onnx_codegen
38+
onnx_proto
39+
opencl_codegen
40+
protobuf
41+
rescript_codegen
42+
rust_codegen
43+
spirv_codegen
44+
verilog_codegen
45+
why3_codegen
46+
lexer
47+
linter
48+
lsp_server
49+
lucid_face
50+
module_loader
51+
opt
52+
parse_driver
53+
parse
54+
parser
55+
parser_errors
56+
pseudocode_face
57+
python_face
58+
quantity
59+
resolve
60+
span
61+
symbol
62+
tea_bridge
63+
tea_cs_bridge
64+
tea_router
65+
token
66+
trait
67+
tw_interface
68+
tw_verify
69+
typecheck
70+
types
71+
unify
72+
value
73+
wasm
74+
wasm_encode
75+
wasm_gc
76+
wasm_gc_encode
77+
wasi_runtime
78+
wgsl_codegen)
679
(libraries str unix sedlex fmt menhirLib yojson)
780
(preprocess
881
(pps ppx_deriving.show ppx_deriving.eq ppx_deriving.ord sedlex.ppx)))

lib/kernel_sublang.ml

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
(* SPDX-License-Identifier: PMPL-1.0-or-later *)
2+
(* SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell *)
3+
4+
(** Shared infrastructure for kernel-sublanguage backends.
5+
6+
The Tier-C backends — WGSL, SPIR-V, CUDA, Metal, OpenCL, Faust, ONNX,
7+
MLIR — all accept a *deliberately restricted* subset of AffineScript.
8+
Each had its own copy of:
9+
10+
- an [<Backend>_unsupported] exception + [unsupported] helper
11+
- a [pick_kernel] / [pick_entry] function that searches for a canonical
12+
entry name and falls back to the first [fn]
13+
- an [array_element] helper that strips ownership and matches
14+
[Array[T]] for a scalar [T]
15+
- a math-builtin allowlist for trig / pow / sqrt / etc.
16+
- an [is_unit_ty] check that accepts both [TyCon "Unit"] and [TyTuple []]
17+
18+
Per the [docs/guides/frontier-programming-practices/AI.a2ml] backends
19+
section: "If two new tier-C backends ship with overlapping
20+
restrictions, factor the validator before adding the third." We're
21+
well past three. This module is the factoring.
22+
23+
Per-target lowering of *concrete* type names (i32 vs i64 vs int vs
24+
long, f32 vs f64 vs float vs double) intentionally stays in each
25+
backend — that's not duplication, that's correctly different. *)
26+
27+
open Ast
28+
29+
(* ============================================================================
30+
Common exception
31+
============================================================================ *)
32+
33+
(** Raised by any kernel-sublanguage backend when the source falls outside
34+
its accepted subset. The string is the user-facing reason; backends
35+
typically prefix it with their name (e.g. "WGSL backend: ..."). *)
36+
exception Unsupported of string
37+
38+
let unsupported (msg : string) : 'a = raise (Unsupported msg)
39+
40+
(* ============================================================================
41+
Entry-function selection
42+
43+
Most kernel backends look for a function named [kernel], [process],
44+
[main], or [graph] and fall back to the first user-defined function in
45+
the program. The default name list mirrors what the existing backends
46+
used; callers can override.
47+
============================================================================ *)
48+
49+
let default_entry_names = ["kernel"; "process"; "main"; "graph"]
50+
51+
let pick_entry ?(names = default_entry_names) (program : program) : fn_decl =
52+
let fns = List.filter_map (function TopFn fd -> Some fd | _ -> None)
53+
program.prog_decls in
54+
let rec try_names = function
55+
| n :: rest ->
56+
(match List.find_opt (fun fd -> fd.fd_name.name = n) fns with
57+
| Some fd -> fd
58+
| None -> try_names rest)
59+
| [] ->
60+
(match fns with
61+
| fd :: _ -> fd
62+
| [] -> unsupported "no function found to lower as kernel")
63+
in
64+
try_names names
65+
66+
(* ============================================================================
67+
Type-shape predicates and helpers
68+
============================================================================ *)
69+
70+
let rec strip_ownership (te : type_expr) : type_expr =
71+
match te with
72+
| TyOwn t | TyRef t | TyMut t -> strip_ownership t
73+
| t -> t
74+
75+
(** Accept [Unit] either as the named type or as the empty tuple, since
76+
[-> ()] parses as [TyTuple []] but [-> Unit] parses as [TyCon "Unit"]. *)
77+
let is_unit_ty (te : type_expr) : bool =
78+
match strip_ownership te with
79+
| TyCon id when id.name = "Unit" -> true
80+
| TyTuple [] -> true
81+
| _ -> false
82+
83+
(** Standard primitive scalar names — the intersection that every kernel
84+
sublanguage we currently target supports. Backends restricting further
85+
(e.g. Faust's Float-only kernels) should layer their own check on top. *)
86+
let is_scalar_type_name (n : string) : bool =
87+
n = "Int" || n = "Float" || n = "Bool"
88+
89+
let is_scalar_type (te : type_expr) : bool =
90+
match strip_ownership te with
91+
| TyCon id -> is_scalar_type_name id.name
92+
| _ -> false
93+
94+
(** Decompose [Array[T]] for a scalar [T], returning the inner type-con
95+
name (e.g. ["Int"], ["Float"]). Returns [None] for non-array or
96+
non-scalar-element shapes. *)
97+
let array_element (te : type_expr) : string option =
98+
match strip_ownership te with
99+
| TyApp (id, [TyArg inner]) when id.name = "Array" ->
100+
(match strip_ownership inner with
101+
| TyCon id when is_scalar_type_name id.name -> Some id.name
102+
| _ -> None)
103+
| _ -> None
104+
105+
(** Same as [array_element] but raises [Unsupported] with a useful message
106+
on shapes that don't decompose. The expected shape string is included
107+
in the error so users see e.g. "expected Array[Float]" rather than a
108+
generic "type not allowed." *)
109+
let require_array_element (expected : string) (te : type_expr) : string =
110+
match array_element te with
111+
| Some name -> name
112+
| None ->
113+
unsupported (Printf.sprintf "expected %s for kernel buffer" expected)
114+
115+
(* ============================================================================
116+
Math builtins shared across kernel backends
117+
118+
Names are taken from the common subset (WGSL spec, GLSL, Metal stdlib,
119+
OpenCL, CUDA math.h, Faust). [int]/[float] are coercions; the rest are
120+
real math intrinsics. Backends may add target-specific names (e.g.
121+
Metal's [metal::float4]) but should start from this list.
122+
============================================================================ *)
123+
124+
let math_builtins : string list = [
125+
"sin"; "cos"; "tan"; "asin"; "acos"; "atan"; "atan2";
126+
"sinh"; "cosh"; "tanh";
127+
"exp"; "log"; "log2"; "log10"; "pow"; "sqrt"; "rsqrt";
128+
"floor"; "ceil"; "round"; "trunc"; "fract"; "fmod";
129+
"abs"; "min"; "max"; "clamp"; "sign";
130+
"mix"; "step"; "smoothstep";
131+
"int"; "float"; "i32"; "u32"; "f32";
132+
]
133+
134+
let is_math_builtin (name : string) : bool = List.mem name math_builtins
135+
136+
(* ============================================================================
137+
Validation helpers
138+
139+
These build on [is_scalar_type], [array_element], and [is_unit_ty] to
140+
produce friendly error messages. They are convenience wrappers — every
141+
backend can implement its own validation if its rules differ.
142+
============================================================================ *)
143+
144+
(** Require the function's return type satisfies [pred]; raise on violation. *)
145+
let validate_return (pred : type_expr -> bool) (expected : string) (fd : fn_decl) : unit =
146+
match fd.fd_ret_ty with
147+
| None -> () (* no annotation — caller decides whether to accept *)
148+
| Some t when pred t -> ()
149+
| _ -> unsupported
150+
(Printf.sprintf "kernel function must return %s" expected)
151+
152+
(** Require every parameter satisfies [pred]; raise on the first that
153+
doesn't, naming it. *)
154+
let validate_params (pred : type_expr -> bool) (expected : string)
155+
(fd : fn_decl) : unit =
156+
List.iter (fun (p : param) ->
157+
if not (pred p.p_ty) then
158+
unsupported
159+
(Printf.sprintf "parameter %s must be %s" p.p_name.name expected)
160+
) fd.fd_params
161+
162+
(** Common shape: first param is [Int] (the global invocation index),
163+
remaining params are [Array[T]] buffers, return type is [Unit]. Used
164+
by WGSL / SPIR-V / CUDA / Metal / OpenCL. *)
165+
let validate_compute_kernel_shape (fd : fn_decl) : unit =
166+
validate_return (fun t -> is_unit_ty t) "Unit or ()" fd;
167+
match fd.fd_params with
168+
| [] -> unsupported "kernel must take at least an Int index parameter"
169+
| first :: rest ->
170+
(match strip_ownership first.p_ty with
171+
| TyCon id when id.name = "Int" -> ()
172+
| _ -> unsupported "first kernel parameter must be Int (the global index)");
173+
List.iter (fun (p : param) ->
174+
ignore (require_array_element "Array[Int|Float]" p.p_ty)
175+
) rest

0 commit comments

Comments
 (0)