Skip to content

[metal] MslAstWalker + MetalBackend MSL codegen + device function hook + tests#1764

Closed
aditvenk wants to merge 1 commit intoaditvenk/stack/10from
aditvenk/stack/11
Closed

[metal] MslAstWalker + MetalBackend MSL codegen + device function hook + tests#1764
aditvenk wants to merge 1 commit intoaditvenk/stack/10from
aditvenk/stack/11

Conversation

@aditvenk
Copy link
Copy Markdown
Contributor

@aditvenk aditvenk commented Mar 20, 2026

Stacked PRs:


[metal] MslAstWalker + MetalBackend MSL codegen + device function hook + tests

  • Add helion/_compiler/metal/ package:
    • msl_ast_walker.py: MslAstWalker class with _generate_elementwise()
      and AST-to-MSL C++ converter functions. Handles tl.cast, tl.where,
      tl.full, tl.reshape, tl.sigmoid, tl.sqrt_rn, libdevice.*,
      triton_helpers.maximum/minimum, select, ast.If for masked stores.
  • Modify codegen_function_def to post-process for Metal: standard path
    runs first (variable renames, constexpr inlining, scalar preambles),
    then MetalBackend.generate_msl_function extracts the FunctionDef body,
    creates an MslAstWalker, and replaces with a zero-arg function
    returning (msl_source, kernel_name).
  • Add MetalBackend methods: transform_host_arg (wraps scalar args as
    float32 1-element buffer tensors), build_launcher_args, inline_constexpr,
    generate_msl_function, and expression generators (program_id_expr,
    cdiv_expr, sympy_printer_expr, grid_index_expr, inductor_op_overrides)
  • In a subsequent change, we will avoid passing scalar args as 1-element
    buffer tensors, and instead bake them directly into the MSL source.
  • 26 tests: arithmetic (add/sub/mul/div/neg), scalar args (saxpy),
    activations (relu/silu/gelu), math ops (exp/log/sqrt/abs/sin/cos/clamp),
    dtypes (float16/bfloat16/int32), bounds masking (OOB sentinel checks,
    codegen mask assertions)

@aditvenk aditvenk force-pushed the aditvenk/stack/10 branch from f238861 to e4ffd2c Compare March 20, 2026 05:36
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from 152bb59 to c9b78ba Compare March 20, 2026 05:36
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 20, 2026
@aditvenk aditvenk changed the title [metal] MslAstWalker + MetalBackend MSL codegen + device function hook + tests [metal] Add lowering for elementwise kernels Mar 20, 2026
@aditvenk aditvenk marked this pull request as draft March 20, 2026 06:19
@aditvenk aditvenk changed the base branch from aditvenk/stack/10 to main March 20, 2026 06:19
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from c9b78ba to fdcf824 Compare March 20, 2026 06:19
@aditvenk aditvenk changed the title [metal] Add lowering for elementwise kernels [metal] MslAstWalker + MetalBackend MSL codegen + device function hook + tests Mar 20, 2026
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/10 March 20, 2026 06:20
@aditvenk aditvenk marked this pull request as ready for review March 20, 2026 06:20
@aditvenk aditvenk marked this pull request as draft March 20, 2026 16:43
@aditvenk aditvenk changed the base branch from aditvenk/stack/10 to main March 20, 2026 16:43
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from fdcf824 to f3d9b96 Compare March 20, 2026 16:43
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/10 March 20, 2026 16:43
@aditvenk aditvenk marked this pull request as ready for review March 20, 2026 16:43
@aditvenk aditvenk requested review from jansel and oulgen and removed request for jansel March 20, 2026 16:49
@aditvenk aditvenk force-pushed the aditvenk/stack/10 branch from 6f5cc8a to 4843e8c Compare March 21, 2026 02:30
aditvenk added a commit that referenced this pull request Mar 21, 2026
…k + tests

- Add helion/_compiler/metal/ package:
  - msl_ast_walker.py: MslAstWalker class with _generate_elementwise()
    and AST-to-MSL C++ converter functions. Handles tl.cast, tl.where,
    tl.full, tl.reshape, tl.sigmoid, tl.sqrt_rn, libdevice.*,
    triton_helpers.maximum/minimum, select, ast.If for masked stores.
- Modify codegen_function_def to post-process for Metal: standard path
  runs first (variable renames, constexpr inlining, scalar preambles),
  then MetalBackend.generate_msl_function extracts the FunctionDef body,
  creates an MslAstWalker, and replaces with a zero-arg function
  returning (msl_source, kernel_name).
- Add MetalBackend methods: transform_host_arg (wraps scalar args as
  float32 1-element buffer tensors), build_launcher_args, inline_constexpr,
  generate_msl_function, and expression generators (program_id_expr,
  cdiv_expr, sympy_printer_expr, grid_index_expr, inductor_op_overrides)
- In a subsequent change, we will avoid passing scalar args as 1-element
  buffer tensors, and instead bake them directly into the MSL source.
- 26 tests: arithmetic (add/sub/mul/div/neg), scalar args (saxpy),
  activations (relu/silu/gelu), math ops (exp/log/sqrt/abs/sin/cos/clamp),
  dtypes (float16/bfloat16/int32), bounds masking (OOB sentinel checks,
  codegen mask assertions)

stack-info: PR: #1764, branch: aditvenk/stack/11
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from f3d9b96 to e578915 Compare March 21, 2026 02:31
@aditvenk aditvenk marked this pull request as draft March 21, 2026 02:45
@aditvenk aditvenk changed the base branch from aditvenk/stack/10 to main March 21, 2026 02:45
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from e578915 to 0bc0be3 Compare March 21, 2026 02:45
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 03:13
@aditvenk aditvenk marked this pull request as draft March 21, 2026 05:54
@aditvenk aditvenk changed the base branch from aditvenk/stack/10 to main March 21, 2026 05:54
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from a8863d9 to 7c001bc Compare March 21, 2026 05:54
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/10 March 21, 2026 05:54
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 05:55
@aditvenk aditvenk marked this pull request as draft March 21, 2026 06:00
@aditvenk aditvenk changed the base branch from aditvenk/stack/10 to main March 21, 2026 06:00
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from 7c001bc to 3cd1378 Compare March 21, 2026 06:00
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/10 March 21, 2026 06:01
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 06:01
@aditvenk aditvenk marked this pull request as draft March 21, 2026 06:56
@aditvenk aditvenk changed the base branch from aditvenk/stack/10 to main March 21, 2026 06:57
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from 3cd1378 to c5cdfa3 Compare March 21, 2026 06:57
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/10 March 21, 2026 06:57
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 06:57
@aditvenk aditvenk marked this pull request as draft March 21, 2026 16:45
@aditvenk aditvenk changed the base branch from aditvenk/stack/10 to main March 21, 2026 16:45
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from c5cdfa3 to ec62c4d Compare March 21, 2026 16:45
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/10 March 21, 2026 16:45
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 16:46
…k + tests

- Add helion/_compiler/metal/ package:
  - msl_ast_walker.py: MslAstWalker class with _generate_elementwise()
    and AST-to-MSL C++ converter functions. Handles tl.cast, tl.where,
    tl.full, tl.reshape, tl.sigmoid, tl.sqrt_rn, libdevice.*,
    triton_helpers.maximum/minimum, select, ast.If for masked stores.
- Modify codegen_function_def to post-process for Metal: standard path
  runs first (variable renames, constexpr inlining, scalar preambles),
  then MetalBackend.generate_msl_function extracts the FunctionDef body,
  creates an MslAstWalker, and replaces with a zero-arg function
  returning (msl_source, kernel_name).
- Add MetalBackend methods: transform_host_arg (wraps scalar args as
  float32 1-element buffer tensors), build_launcher_args, inline_constexpr,
  generate_msl_function, and expression generators (program_id_expr,
  cdiv_expr, sympy_printer_expr, grid_index_expr, inductor_op_overrides)
- In a subsequent change, we will avoid passing scalar args as 1-element
  buffer tensors, and instead bake them directly into the MSL source.
- 26 tests: arithmetic (add/sub/mul/div/neg), scalar args (saxpy),
  activations (relu/silu/gelu), math ops (exp/log/sqrt/abs/sin/cos/clamp),
  dtypes (float16/bfloat16/int32), bounds masking (OOB sentinel checks,
  codegen mask assertions)

stack-info: PR: #1764, branch: aditvenk/stack/11
@aditvenk aditvenk marked this pull request as draft March 21, 2026 16:54
@aditvenk aditvenk changed the base branch from aditvenk/stack/10 to main March 21, 2026 16:54
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from ec62c4d to 9f36d70 Compare March 21, 2026 16:54
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/10 March 21, 2026 16:54
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 16:54
@aditvenk aditvenk marked this pull request as draft March 22, 2026 04:47
@aditvenk
Copy link
Copy Markdown
Contributor Author

Will shortly open a new PR that refactors the commits in the stack differently

@aditvenk aditvenk closed this Mar 23, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant