[metal] Support elementwise kernels with >1D tensors #1765
Closed
aditvenk wants to merge 1 commit into aditvenk/stack/11 from aditvenk/stack/12
Conversation
aditvenk added a commit that referenced this pull request on Mar 20, 2026

Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n])) on the Metal backend by fixing four issues in the per-thread dispatch model:
- grid_index_expr: return offset_var (from PID decomposition) instead of _gid, so each dimension gets its own computed index
- program_id_expr: return _gid directly instead of _pid0, since the flat thread ID is the program ID for per-thread dispatch
- build_launcher_args: use the product of all block_sizes for _block_size so enough threads are dispatched (e.g., 32*32=1024 for 2D)
- MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat Metal device pointers

stack-info: PR: #1765, branch: aditvenk/stack/12
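The first three fixes above can be illustrated with a small Python sketch (not the actual Helion/Metal codegen): a flat thread ID `_gid` is decomposed into per-dimension offsets, and the launcher dispatches the product of all block sizes so every offset pair is covered. The 32x32 block sizes are illustrative.

```python
# Illustrative sketch of per-thread dispatch index math; block sizes are
# hypothetical, matching the 32*32=1024 example in the commit message.
BLOCK_SIZE_0 = 32
BLOCK_SIZE_1 = 32

def decompose(gid: int) -> tuple[int, int]:
    """Split the flat thread ID into (i, j) offsets for a 2D tile,
    mirroring the PID decomposition that grid_index_expr now returns."""
    i = gid // BLOCK_SIZE_1  # offset along dimension 0
    j = gid % BLOCK_SIZE_1   # offset along dimension 1
    return i, j

# build_launcher_args must use the product of all block sizes so that
# enough threads are dispatched to cover the full 2D tile.
num_threads = BLOCK_SIZE_0 * BLOCK_SIZE_1
assert num_threads == 1024
assert decompose(0) == (0, 0)
assert decompose(33) == (1, 1)
```

Each thread then uses its own `(i, j)` pair rather than the raw flat ID, which is why every dimension needs its own computed index.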
jansel reviewed on Mar 20, 2026
aditvenk (Contributor, Author): Will shortly open a new PR that refactors the commits in the stack differently
Stacked PRs:

[metal] Support elementwise kernels with >1D tensors

Enable multi-dimensional elementwise kernels (e.g., hl.tile([m, n])) on the Metal backend by fixing four issues in the per-thread dispatch model:
- grid_index_expr: return offset_var (from PID decomposition) instead of _gid, so each dimension gets its own computed index
- program_id_expr: return _gid directly instead of _pid0, since the flat thread ID is the program ID for per-thread dispatch
- build_launcher_args: use the product of all block_sizes for _block_size so enough threads are dispatched (e.g., 32*32=1024 for 2D)
- MSL walker: emit constexpr _BLOCK_SIZE_N = 1 definitions and linearize multi-dim subscripts (x[i, j] → x[i*stride+j]) for flat Metal device pointers
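The subscript linearization in the last bullet can be sketched as follows. A Metal device pointer is a flat 1D buffer, so a 2D subscript x[i, j] must be rewritten as x[i*stride + j]; the (4, 8) tensor shape and contiguous row-major stride here are illustrative assumptions, not taken from the PR.

```python
# Illustrative sketch: mapping a 2D subscript onto a flat buffer, as the
# MSL walker does when rewriting x[i, j] -> x[i*stride + j].
def linearize(i: int, j: int, stride: int) -> int:
    """Flat offset for element (i, j) with the given row stride."""
    return i * stride + j

m, n = 4, 8                  # hypothetical contiguous (m, n) tensor
flat = list(range(m * n))    # flat buffer standing in for device memory

# For a contiguous row-major tensor the row stride equals n, so
# x[2, 3] maps to flat index 2*8 + 3 = 19.
assert linearize(2, 3, n) == 19
assert flat[linearize(2, 3, n)] == 19
```

For non-contiguous tensors the stride would come from the tensor's stride metadata rather than its shape, which is why the rewrite is expressed in terms of a stride and not the dimension size.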