Skip to content

Add block sparse linear and locally connected layers#6

Draft
clane9 wants to merge 9 commits into
mainfrom
blocksparse
Draft

Add block sparse linear and locally connected layers#6
clane9 wants to merge 9 commits into
mainfrom
blocksparse

Conversation

@clane9
Copy link
Copy Markdown
Owner

@clane9 clane9 commented May 1, 2024

No description provided.

alismil and others added 6 commits July 1, 2024 14:05
- Use pytorch native blocksparse tensors following [this
  blog](https://pytorch.org/blog/speeding-up-vits/) rather than triton
  blocksparse, which seems not very stable. See e.g. [this recent
  PR](triton-lang/triton#4156) where
  `triton.ops` were deprecated.
- Represent weights in `BlockSparseLinear` as block-sparse tensor rather
  than dense tensor. This will save a lot of gpu memory.
- Change `BlockSparseLocallyConnected` interface to more closely match
  `nn.Conv2d`. Except remove support for padding and stride. For now we
  should be able to restrict to `padding="same"`, `stride=1` (?)
- Rewrite function to construct local connectivity matrix to directly
  construct a sparse rather than dense matrix. Use vectorized
  operations rather than for loops for construction. This should save a
  lot of memory and run faster.
- Add support for multiple input and output channels and depthwise
  convolution. The channels axis can either be first or last. For
  depthwise convolution, first should be more efficient (more block
  sparsity).

TODO:
- Finish testing on cuda. native blocksparse matmul is not implemented on CPU.
@clane9
Copy link
Copy Markdown
Owner Author

clane9 commented Jul 2, 2024

I updated the implementation:

  • Use pytorch native blocksparse tensors following this blog rather than triton blocksparse, which seems not yet stable.
  • Represent weights in BlockSparseLinear as block-sparse tensor rather than dense tensor. This will save a lot of gpu memory.
  • Change BlockSparseLocallyConnected interface to more closely match nn.Conv2d. Except remove support for padding and stride. For now we should be able to restrict to padding="same", stride=1 (@alismil what do you think?).
  • Rewrite function for constructing local connectivity matrix to directly construct a sparse rather than dense matrix. Use vectorized operations rather than for loops for construction. This should save a lot of memory and run faster.
  • Add support for multiple input and output channels and depthwise convolution. The channels axis can either be first or last. For depthwise convolution, first should be more efficient (more block sparsity).

Comment thread columnformers/models/layers.py Outdated
@alismil
Copy link
Copy Markdown
Collaborator

alismil commented Jul 2, 2024

For now we should be able to restrict to padding="same", stride=1 (@alismil what do you think?).

If we want to use this for All TNNs, to make the dimensions work out we have to use different strides for each layer

Could not use `sparse_bsr` layout weight parameter. It fails to map to
cuda correctly when calling `model.cuda()`. This is probably a bug that
should be reported. But as a workaround I just unpack the
`crow_indices`, `col_indices` as buffers and store the sparse bsr weight
values as a standard strided parameter. Then I construct the sparse bsr
weight tensor on the fly during forward.

TODO: backward does not work. Raises

```
RuntimeError: addmm: computation on CUDA is not implemented for Strided + Strided @ SparseBsr
```
@clane9
Copy link
Copy Markdown
Owner Author

clane9 commented Jul 3, 2024

If we want to use this for All TNNs, to make the dimensions work out we have to use different strides for each layer

Ok good point. Then maybe we should add back the support you had for strides and padding to _sparse_local_connectivity. Hopefully it doesn't complicate things too much.


# conv kernel index offsets. note that the kernel width is required to be odd.
# (k^2, 2)
kernel_half_width = (kernel_size - 1) // 2
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit restrictive, can we adapt this to also include even kernel size by doing something like this?

if kernel_size % 2 == 0:
    kernel_half_width = kernel_size // 2
    kernel_indices = torch.cartesian_prod(
        torch.arange(-kernel_half_width, kernel_half_width),
        torch.arange(-kernel_half_width, kernel_half_width),
    )
else:
    kernel_half_width = (kernel_size - 1) // 2
    kernel_indices = torch.cartesian_prod(
        torch.arange(-kernel_half_width, kernel_half_width + 1),
        torch.arange(-kernel_half_width, kernel_half_width + 1),
    )

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think you're right, something like this would probably be better. Although it feels like it should be possible to make the code shorter.

More generally, it would probably be best to have exactly the same interface and behavior as Conv2d. What I have now takes a few shortcuts.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants