Add block sparse linear and locally connected layers #6
- Use PyTorch native blocksparse tensors following [this blog](https://pytorch.org/blog/speeding-up-vits/) rather than Triton blocksparse, which seems not very stable. See e.g. [this recent PR](triton-lang/triton#4156) where `triton.ops` were deprecated.
- Represent weights in `BlockSparseLinear` as a block-sparse tensor rather than a dense tensor. This will save a lot of GPU memory.
- Change the `BlockSparseLocallyConnected` interface to more closely match `nn.Conv2d`, except remove support for padding and stride. For now we should be able to restrict to `padding="same"`, `stride=1` (?)
- Rewrite the function that constructs the local connectivity matrix to directly construct a sparse rather than dense matrix. Use vectorized operations rather than for loops for construction. This should save a lot of memory and run faster.
- Add support for multiple input and output channels and depthwise convolution. The channels axis can either be first or last. For depthwise convolution, first should be more efficient (more block sparsity).

TODO:

- Finish testing on CUDA. Native blocksparse matmul is not implemented on CPU.
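To illustrate the vectorized sparse construction mentioned above, here is a minimal sketch, not the code in this PR. It assumes a single channel, `stride=1`, zero `"same"` padding, and an odd kernel size; the function name `local_connectivity_indices` is hypothetical. It builds the `(row, col)` index pairs by broadcasting instead of looping, so no dense `(H*W, H*W)` matrix is ever allocated.

```python
import torch


def local_connectivity_indices(height, width, kernel_size):
    """Return (2, nnz) [row, col] indices linking each output pixel to the
    input pixels inside its kernel window (stride 1, zero "same" padding).

    Hypothetical helper for illustration; assumes an odd kernel_size.
    """
    # Kernel index offsets, shape (k^2, 2).
    half = (kernel_size - 1) // 2
    offsets = torch.arange(-half, half + 1)
    kernel_offsets = torch.cartesian_prod(offsets, offsets)

    # Output pixel coordinates in row-major order, shape (H*W, 2).
    pixels = torch.cartesian_prod(torch.arange(height), torch.arange(width))

    # All (pixel, offset) pairs via broadcasting, shape (H*W, k^2, 2).
    neighbors = pixels[:, None, :] + kernel_offsets[None, :, :]

    # Drop neighbors that fall outside the image, mask shape (H*W, k^2).
    limits = torch.tensor([height, width])
    valid = ((neighbors >= 0) & (neighbors < limits)).all(dim=-1)

    rows = torch.arange(height * width)[:, None].expand_as(valid)[valid]
    cols = (neighbors[..., 0] * width + neighbors[..., 1])[valid]
    return torch.stack([rows, cols])


# Example: 3x3 image with a 3x3 kernel, stored as a sparse COO matrix.
indices = local_connectivity_indices(3, 3, 3)
connectivity = torch.sparse_coo_tensor(
    indices, torch.ones(indices.shape[1]), size=(9, 9)
)
```

In this example the four corner pixels each see 4 inputs, the four edge pixels 6, and the center pixel 9, for 49 nonzeros in total. Converting such a matrix to a block-sparse layout with channels is a separate step that is not shown here.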
I updated the implementation:
If we want to use this for All TNNs, to make the dimensions work out we have to use different strides for each layer.
I could not use a `sparse_bsr` layout weight parameter. It fails to map to CUDA correctly when calling `model.cuda()`. This is probably a bug that should be reported. As a workaround, I unpack `crow_indices` and `col_indices` as buffers and store the sparse BSR weight values as a standard strided parameter. Then I construct the sparse BSR weight tensor on the fly during forward.

TODO: backward does not work. It raises

```
RuntimeError: addmm: computation on CUDA is not implemented for Strided + Strided @ SparseBsr
```
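For readers following along, the workaround described above might look roughly like the sketch below. This is an illustration under assumptions, not the code in this PR: the class name `BsrLinearSketch`, its constructor arguments, and the method `bsr_weight` are hypothetical. The index tensors are registered as buffers so they move with `model.cuda()`, and only the block values are a trainable parameter.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BsrLinearSketch(nn.Module):
    """Hypothetical linear layer that stores a BSR weight as plain tensors."""

    def __init__(self, dense_weight, blocksize):
        super().__init__()
        # Convert once to BSR just to extract its components.
        bsr = dense_weight.to_sparse_bsr(blocksize)
        self.weight_shape = tuple(dense_weight.shape)
        # Index tensors are buffers: moved by .cuda(), not trained.
        self.register_buffer("crow_indices", bsr.crow_indices().clone())
        self.register_buffer("col_indices", bsr.col_indices().clone())
        # Block values are an ordinary strided parameter,
        # shape (num_blocks, block_rows, block_cols).
        self.values = nn.Parameter(bsr.values().clone())

    def bsr_weight(self):
        # Rebuild the sparse BSR weight on the fly from its components.
        return torch.sparse_bsr_tensor(
            self.crow_indices,
            self.col_indices,
            self.values,
            size=self.weight_shape,
        )

    def forward(self, x):
        # Per the notes above, the block-sparse matmul is only expected to
        # work on CUDA, and backward through it currently fails.
        return F.linear(x, self.bsr_weight())


# Example: a 4x4 block-diagonal weight with two 2x2 blocks.
dense = torch.zeros(4, 4)
dense[:2, :2] = 1.0
dense[2:, 2:] = 2.0
layer = BsrLinearSketch(dense, blocksize=(2, 2))
```

The example only constructs the layer; calling `forward` is left out because, as noted in the PR description, the native blocksparse matmul is not implemented on CPU.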
Ok good point. Then maybe we should add back the support you had for strides and padding to
```python
# conv kernel index offsets. note that the kernel width is required to be odd.
# (k^2, 2)
kernel_half_width = (kernel_size - 1) // 2
```
This is a bit restrictive, can we adapt this to also include even kernel size by doing something like this?
```python
if kernel_size % 2 == 0:
    # even kernel: offsets run from -k/2 to k/2 - 1
    kernel_half_width = kernel_size // 2
    kernel_indices = torch.cartesian_prod(
        torch.arange(-kernel_half_width, kernel_half_width),
        torch.arange(-kernel_half_width, kernel_half_width),
    )
else:
    # odd kernel: offsets are symmetric around 0
    kernel_half_width = (kernel_size - 1) // 2
    kernel_indices = torch.cartesian_prod(
        torch.arange(-kernel_half_width, kernel_half_width + 1),
        torch.arange(-kernel_half_width, kernel_half_width + 1),
    )
```
Yes I think you're right, something like this would probably be better. Although it feels like it should be possible to make the code shorter.
More generally, it would probably be best to have exactly the same interface and behavior as Conv2d. What I have now takes a few shortcuts.
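One possible shorter form, offered as a sketch rather than the final code: both branches of the suggestion produce `kernel_size` consecutive offsets starting at `-(kernel_size // 2)`, so a single expression can cover odd and even sizes. The function name `kernel_index_offsets` is hypothetical.

```python
import torch


def kernel_index_offsets(kernel_size):
    """Return (k^2, 2) kernel index offsets for odd or even kernel_size."""
    # Odd k=3  -> [-1, 0, 1]; even k=4 -> [-2, -1, 0, 1].
    offsets = torch.arange(kernel_size) - kernel_size // 2
    return torch.cartesian_prod(offsets, offsets)


odd_offsets = kernel_index_offsets(3)
even_offsets = kernel_index_offsets(4)
```

For even sizes this keeps the same convention as the suggestion above, where the window extends one step further in the negative direction. Whether that matches the `nn.Conv2d` convention for `padding="same"` would need to be checked separately.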