Support block_ptr/TensorDescriptor with extra_mask for loads #1768

Merged
jansel merged 16 commits into pytorch:main from hinriksnaer:extra-load-mask
Mar 26, 2026
Conversation

@hinriksnaer
Collaborator

Addresses #97

This is somewhat of a new area for me, and the changes are based on my interpretation of the #97 description. I'm looking for feedback here to make sure there isn't something I missed or should approach from a different angle.

Approach

This enables block_ptr and TensorDescriptor support for hl.load(..., extra_mask=...). Previously, this would fall back to pointer indexing; this approach instead decomposes the mask into a separate epilogue (along with an aten.view for cases where the mask rank differs from the load rank).

Before:

%load = call_function[target=hl.load](%x, [%tile_m, %tile_n, %tile_k], extra_mask=%mask)
%store = call_function[target=hl.store](%out, ..., %load)

After:

%load = call_function[target=hl.load](%x, [%tile_m, %tile_n, %tile_k], extra_mask=None)
%where = call_function[target=aten.where.ScalarOther](%mask, %load, 0)
%store = call_function[target=hl.store](%out, ..., %where)
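The same-rank decomposition above can be modeled in plain PyTorch. This is only a sketch of the semantics — hl.load runs inside a Helion kernel, so the unmasked load is stood in for here by the tensor itself, and the helper names are hypothetical:

```python
import torch

def masked_load_reference(x, mask):
    # Semantics of hl.load(x, ..., extra_mask=mask): elements where the
    # mask is False read as 0. (Standalone model of the op, not Helion API.)
    return torch.where(mask, x, torch.zeros((), dtype=x.dtype))

def decomposed_load(x, mask):
    # The decomposition: an unmasked load (here, just x) followed by a
    # where epilogue that zeroes masked-off elements.
    load = x  # stands in for hl.load(..., extra_mask=None)
    return torch.where(mask, load, 0)  # aten.where.ScalarOther

x = torch.arange(12.0).reshape(3, 4)
mask = x < 6
assert torch.equal(masked_load_reference(x, mask), decomposed_load(x, mask))
```

Both paths produce the same tensor, which is what lets the mask be peeled off the load and applied afterwards.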

For a non-matching mask rank:

Before:

%load = call_function[target=hl.load](%x, [%tile_m, %tile_n, %tile_k], extra_mask=%mask)
%store = call_function[target=hl.store](%out, ..., %load)

After:

%load = call_function[target=hl.load](%x, [%tile_m, %tile_n, %tile_k], extra_mask=None)
%view = call_function[target=aten.view](%mask, [block_M, 1, 1])
%where = call_function[target=aten.where.ScalarOther](%view, %load, 0)
%store = call_function[target=hl.store](%out, ..., %where)
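The rank-mismatch path can likewise be sketched in plain PyTorch: the rank-1 mask is viewed to [block_M, 1, 1] so it broadcasts over the 3-D load in the where epilogue. The tile sizes below are hypothetical, and a random tensor stands in for the unmasked load:

```python
import torch

block_M, block_N, block_K = 4, 3, 2
load = torch.randn(block_M, block_N, block_K)       # unmasked 3-D load
row_mask = torch.tensor([True, False, True, True])  # rank-1 mask over M

# aten.view reshapes the mask so it broadcasts over the trailing dims.
view = row_mask.view(block_M, 1, 1)
out = torch.where(view, load, 0)  # aten.where.ScalarOther

assert out[1].abs().sum().item() == 0.0  # masked-off row zeroed
assert torch.equal(out[0], load[0])      # kept rows unchanged
```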

@meta-cla meta-cla bot added the CLA Signed label Mar 20, 2026
Contributor

@jansel jansel left a comment


I think if we always decompose these it will be a performance regression for non-block ptrs.

I also worry this will interact badly with the existing mask_to op and optimizations to remove masks. (We propagate masking information.)

Maybe we should also do this pass at codegen time and only if block pointers are chosen for the given op.

@hinriksnaer
Collaborator Author

hinriksnaer commented Mar 20, 2026

These are all very good points and I appreciate the broader context.

Sounds to me like the next steps in the right direction would be:

  • Remove the decomposition from the fx graph
  • Move the functionality into the generated Triton code when TensorDescriptor / block ptrs are selected
  • Modify is_supported so it no longer rejects masked loads for non-pointer-indexing strategies
  • Update tests

Anything you would like to add?

@hinriksnaer hinriksnaer requested a review from jansel March 23, 2026 16:42
@hinriksnaer
Collaborator Author

Made changes based on your feedback @jansel.

Assuming this looks good, would you want me to add support for lower-rank masks in a future PR? e.g.

out[tile_m, tile_n] = hl.load(
    x, [tile_m, tile_n], extra_mask=row_mask[tile_m]
)
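For reference, a plain-PyTorch sketch of what that lowering could look like, assuming the rank-1 row_mask would be reshaped to broadcast across tile_n (the names and tile sizes are illustrative, not Helion API):

```python
import torch

tile_m, tile_n = 4, 5
x = torch.randn(tile_m, tile_n)                      # stands in for the 2-D load
row_mask = torch.tensor([True, True, False, True])   # rank-1 mask over tile_m

# View the rank-1 mask as a column so it broadcasts across tile_n.
out = torch.where(row_mask.view(tile_m, 1), x, 0)
assert torch.equal(out[2], torch.zeros(tile_n))  # masked-off row zeroed
```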

@hinriksnaer
Collaborator Author

@jansel all tests are passing.

@jansel jansel merged commit 733351e into pytorch:main Mar 26, 2026
18 checks passed