Skip to content

[metal] Support elementwise kernels with >1D tensors#1765

Closed
aditvenk wants to merge 1 commit intoaditvenk/stack/11from
aditvenk/stack/12
Closed

[metal] Support elementwise kernels with >1D tensors#1765
aditvenk wants to merge 1 commit intoaditvenk/stack/11from
aditvenk/stack/12

Conversation

@aditvenk
Copy link
Copy Markdown
Contributor

@aditvenk aditvenk commented Mar 20, 2026

Stacked PRs:


[metal] Support elementwise kernels with >1D tensors

Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n]))
on the Metal backend by fixing four issues in the per-thread dispatch
model:

  • grid_index_expr: return offset_var (from PID decomposition) instead
    of _gid, so each dimension gets its own computed index
  • program_id_expr: return _gid directly instead of _pid0, since the
    flat thread ID is the program ID for per-thread dispatch
  • build_launcher_args: use product of all block_sizes for _block_size
    so enough threads are dispatched (e.g., 32*32=1024 for 2D)
  • MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and
    linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat
    Metal device pointers

aditvenk added a commit that referenced this pull request Mar 20, 2026
Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n]))
on the Metal backend by fixing four issues in the per-thread dispatch
model:

- grid_index_expr: return offset_var (from PID decomposition) instead
  of _gid, so each dimension gets its own computed index
- program_id_expr: return _gid directly instead of _pid0, since the
  flat thread ID is the program ID for per-thread dispatch
- build_launcher_args: use product of all block_sizes for _block_size
  so enough threads are dispatched (e.g., 32*32=1024 for 2D)
- MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and
  linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat
  Metal device pointers

stack-info: PR: #1765, branch: aditvenk/stack/12
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from 152bb59 to c9b78ba Compare March 20, 2026 05:36
@aditvenk aditvenk force-pushed the aditvenk/stack/12 branch from 224be1f to 17f608d Compare March 20, 2026 05:36
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 20, 2026
@aditvenk aditvenk marked this pull request as draft March 20, 2026 06:19
@aditvenk aditvenk changed the base branch from aditvenk/stack/11 to main March 20, 2026 06:19
aditvenk added a commit that referenced this pull request Mar 20, 2026
Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n]))
on the Metal backend by fixing four issues in the per-thread dispatch
model:

- grid_index_expr: return offset_var (from PID decomposition) instead
  of _gid, so each dimension gets its own computed index
- program_id_expr: return _gid directly instead of _pid0, since the
  flat thread ID is the program ID for per-thread dispatch
- build_launcher_args: use product of all block_sizes for _block_size
  so enough threads are dispatched (e.g., 32*32=1024 for 2D)
- MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and
  linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat
  Metal device pointers

stack-info: PR: #1765, branch: aditvenk/stack/12
@aditvenk aditvenk force-pushed the aditvenk/stack/12 branch from 17f608d to 9ec5962 Compare March 20, 2026 06:19
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/11 March 20, 2026 06:20
@aditvenk aditvenk marked this pull request as ready for review March 20, 2026 06:20
@aditvenk aditvenk marked this pull request as draft March 20, 2026 16:43
@aditvenk aditvenk changed the base branch from aditvenk/stack/11 to main March 20, 2026 16:43
aditvenk added a commit that referenced this pull request Mar 20, 2026
Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n]))
on the Metal backend by fixing four issues in the per-thread dispatch
model:

- grid_index_expr: return offset_var (from PID decomposition) instead
  of _gid, so each dimension gets its own computed index
- program_id_expr: return _gid directly instead of _pid0, since the
  flat thread ID is the program ID for per-thread dispatch
- build_launcher_args: use product of all block_sizes for _block_size
  so enough threads are dispatched (e.g., 32*32=1024 for 2D)
- MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and
  linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat
  Metal device pointers

stack-info: PR: #1765, branch: aditvenk/stack/12
@aditvenk aditvenk force-pushed the aditvenk/stack/12 branch from 9ec5962 to c25549d Compare March 20, 2026 16:43
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/11 March 20, 2026 16:43
@aditvenk aditvenk marked this pull request as ready for review March 20, 2026 16:43
@aditvenk aditvenk requested review from jansel and oulgen and removed request for jansel March 20, 2026 16:51
Comment thread helion/_compiler/metal/msl_ast_walker.py Outdated
Comment thread helion/_compiler/backend.py
@aditvenk aditvenk force-pushed the aditvenk/stack/11 branch from f3d9b96 to e578915 Compare March 21, 2026 02:31
aditvenk added a commit that referenced this pull request Mar 21, 2026
Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n]))
on the Metal backend by fixing four issues in the per-thread dispatch
model:

- grid_index_expr: return offset_var (from PID decomposition) instead
  of _gid, so each dimension gets its own computed index
- program_id_expr: return _gid directly instead of _pid0, since the
  flat thread ID is the program ID for per-thread dispatch
- build_launcher_args: use product of all block_sizes for _block_size
  so enough threads are dispatched (e.g., 32*32=1024 for 2D)
- MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and
  linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat
  Metal device pointers

stack-info: PR: #1765, branch: aditvenk/stack/12
@aditvenk aditvenk force-pushed the aditvenk/stack/12 branch from c25549d to 851416c Compare March 21, 2026 02:31
@aditvenk aditvenk marked this pull request as draft March 21, 2026 02:45
@aditvenk aditvenk force-pushed the aditvenk/stack/12 branch from 065c05f to 40197f8 Compare March 21, 2026 05:54
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/11 March 21, 2026 05:54
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 05:55
@aditvenk aditvenk marked this pull request as draft March 21, 2026 06:00
@aditvenk aditvenk changed the base branch from aditvenk/stack/11 to main March 21, 2026 06:00
aditvenk added a commit that referenced this pull request Mar 21, 2026
Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n]))
on the Metal backend by fixing four issues in the per-thread dispatch
model:

- grid_index_expr: return offset_var (from PID decomposition) instead
  of _gid, so each dimension gets its own computed index
- program_id_expr: return _gid directly instead of _pid0, since the
  flat thread ID is the program ID for per-thread dispatch
- build_launcher_args: use product of all block_sizes for _block_size
  so enough threads are dispatched (e.g., 32*32=1024 for 2D)
- MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and
  linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat
  Metal device pointers

stack-info: PR: #1765, branch: aditvenk/stack/12
@aditvenk aditvenk force-pushed the aditvenk/stack/12 branch from 40197f8 to bcb438d Compare March 21, 2026 06:00
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/11 March 21, 2026 06:01
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 06:01
@aditvenk aditvenk marked this pull request as draft March 21, 2026 06:57
@aditvenk aditvenk changed the base branch from aditvenk/stack/11 to main March 21, 2026 06:57
aditvenk added a commit that referenced this pull request Mar 21, 2026
Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n]))
on the Metal backend by fixing four issues in the per-thread dispatch
model:

- grid_index_expr: return offset_var (from PID decomposition) instead
  of _gid, so each dimension gets its own computed index
- program_id_expr: return _gid directly instead of _pid0, since the
  flat thread ID is the program ID for per-thread dispatch
- build_launcher_args: use product of all block_sizes for _block_size
  so enough threads are dispatched (e.g., 32*32=1024 for 2D)
- MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and
  linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat
  Metal device pointers

stack-info: PR: #1765, branch: aditvenk/stack/12
@aditvenk aditvenk force-pushed the aditvenk/stack/12 branch from bcb438d to 4e1b33b Compare March 21, 2026 06:57
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/11 March 21, 2026 06:57
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 06:57
@aditvenk aditvenk marked this pull request as draft March 21, 2026 16:45
@aditvenk aditvenk changed the base branch from aditvenk/stack/11 to main March 21, 2026 16:45
aditvenk added a commit that referenced this pull request Mar 21, 2026
Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n]))
on the Metal backend by fixing four issues in the per-thread dispatch
model:

- grid_index_expr: return offset_var (from PID decomposition) instead
  of _gid, so each dimension gets its own computed index
- program_id_expr: return _gid directly instead of _pid0, since the
  flat thread ID is the program ID for per-thread dispatch
- build_launcher_args: use product of all block_sizes for _block_size
  so enough threads are dispatched (e.g., 32*32=1024 for 2D)
- MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and
  linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat
  Metal device pointers

stack-info: PR: #1765, branch: aditvenk/stack/12
@aditvenk aditvenk force-pushed the aditvenk/stack/12 branch from 4e1b33b to 9862615 Compare March 21, 2026 16:45
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/11 March 21, 2026 16:45
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 16:46
Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n]))
on the Metal backend by fixing four issues in the per-thread dispatch
model:

- grid_index_expr: return offset_var (from PID decomposition) instead
  of _gid, so each dimension gets its own computed index
- program_id_expr: return _gid directly instead of _pid0, since the
  flat thread ID is the program ID for per-thread dispatch
- build_launcher_args: use product of all block_sizes for _block_size
  so enough threads are dispatched (e.g., 32*32=1024 for 2D)
- MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and
  linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat
  Metal device pointers

stack-info: PR: #1765, branch: aditvenk/stack/12
@aditvenk aditvenk marked this pull request as draft March 21, 2026 16:54
@aditvenk aditvenk changed the base branch from aditvenk/stack/11 to main March 21, 2026 16:54
@aditvenk aditvenk force-pushed the aditvenk/stack/12 branch from 9862615 to 70ddbaf Compare March 21, 2026 16:54
@aditvenk aditvenk changed the base branch from main to aditvenk/stack/11 March 21, 2026 16:54
@aditvenk aditvenk marked this pull request as ready for review March 21, 2026 16:54
@aditvenk aditvenk marked this pull request as draft March 22, 2026 04:47
@aditvenk
Copy link
Copy Markdown
Contributor Author

Will shortly open a new PR that refactors the commits in the stack differently

@aditvenk aditvenk closed this Mar 23, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants