Rework index manipulation API #416
Co-authored-by: Jutho <Jutho@users.noreply.github.com>
Docs content is being added back in a stacked follow-up PR to keep this one reviewable. The minimal docs/src/lib/tensors.md change is kept here because removing the @docs block for the now-deprecated add_permute!/add_braid!/add_transpose! wrappers is required for the docs build to succeed.
I've separated out the docs changes so that the two distinct changes in this PR are not conflated. I think everything here should pass right now, so it should be ready for a round of review before it is merged.
function braid(
        t::AdjointTensorMap, (p₁, p₂)::Index2Tuple, levels::IndexTuple;
        kwargs...
    )
    p₁′ = adjointtensorindices(t, p₂)
    p₂′ = adjointtensorindices(t, p₁)
    perm = adjointtensorindices(adjoint(t), ntuple(identity, numind(t)))
    levels′ = TupleTools.getindices(levels, perm)
    return adjoint(braid(adjoint(t), (p₁′, p₂′), levels′; kwargs...))
end
Is this a completely new definition? I am wondering about its correctness, in particular with respect to the definition of levels′. Given that the adjoint of an overbraid is an underbraid, it might be that we want to change levels to map(-, levels), in combination with applying the permutation perm?
Ok, in a simple example that I tested it with, the current implementation is correct, so never mind.
I did end up adding some tests for this, since it is indeed a new implementation. I think I've convinced myself that, while you do want underbraids on the adjoint, this ends up being the case because the index order gets reversed.
What index order? I am still confused. The index mappings are such that you do the corresponding braid on the adjoint tensor? In what sense is that reversing index order?
As a side note (since I like derailing PRs 😄 ), one thing I noticed in trying to test this, is that copy(::AdjointTensor) produces a new ::AdjointTensor. This is different from Matrix, where copy of an adjoint produces a regular Matrix. Also TensorMap(::AdjointTensorMap) doesn't work, so it took me a while to find a good way to reinstantiate an AdjointTensorMap as its corresponding TensorMap.
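For reference, the Matrix behavior mentioned in the side note can be checked with a minimal sketch that uses only the LinearAlgebra standard library. It says nothing about how AdjointTensorMap should behave; the matrix values are made up for illustration.

```julia
using LinearAlgebra

A = ComplexF64[1+2im 2 3; 4 5 6-1im]
Aadj = A'          # lazy wrapper around A, no data is copied
B = copy(Aadj)     # materializes the adjoint into a plain Matrix

# The copy owns its data, so mutating it leaves the parent untouched.
B[1, 1] = 0
```

Here `Aadj isa Adjoint` holds while `B isa Matrix{ComplexF64}`, which is the asymmetry with the tensor case described above.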
end
# 2. Recoupling: buffer_dst = buffer_src * U^T (each output tree is a linear
#    combination of input trees weighted by the recoupling coefficients).
U′ = Adapt.adapt(typeof(data_dst), StridedView(U))
What is the point of wrapping U in a StridedView here before the adapt call? data_dst is not a StridedView at this point, right? Is this equivalent to StridedView(Adapt.adapt(typeof(data_dst), U))?
It is, in spirit. The reason for the change in order is that storagetype yields something with ndims = 1, while U is a Matrix, so we have the freedom to be slightly more liberal with the strided implementation, but the standard CuArray doesn't capture that.
# using a trivial permutation so the layout is canonical before the matmul.
@inbounds for (i, struct_src_i) in enumerate(structs_src)
    TO.tensoradd!(
        sreshape(buffer_src[:, i], sz_src), StridedView(data_src, sz_src, struct_src_i...),
This contains a getindex that relies on StridedView producing a view, so if we ever change this behavior, this is where we will have to be careful.
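To make the concern concrete, here is a small sketch with a plain Array, where range indexing copies rather than aliases. It only illustrates what would go wrong if `buffer_src[:, i]` stopped returning a view; it does not use StridedView itself.

```julia
A = collect(reshape(1.0:6.0, 3, 2))

col_copy = A[:, 1]        # Array getindex with a range allocates a copy
col_view = view(A, :, 1)  # an explicit view aliases the parent's memory

col_copy .= 0             # A is unchanged
col_view .= -1            # A's first column is overwritten
```

With copy semantics the write into the column would silently land in a temporary, so the buffer would never be filled.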
# 1. Extract: copy each source block into column i of buffer_src as a flat vector,
#    using a trivial permutation so the layout is canonical before the matmul.
@inbounds for (i, struct_src_i) in enumerate(structs_src)
    TO.tensoradd!(
Does it make sense to simply call copy! for the ptriv case?
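As a sanity check on the idea, a sketch with plain Base arrays shows that an identity permutation and `copy!` give the same result. This assumes the scaling factors are α = 1 and β = 0 and that source and destination have matching axes; whether `copy!` is efficient for the strided views used in the kernel is not addressed here.

```julia
src = collect(reshape(1.0:24.0, 2, 3, 4))
dst_permute = similar(src)
dst_copy = similar(src)

permutedims!(dst_permute, src, (1, 2, 3))  # identity permutation
copy!(dst_copy, src)                       # plain element-wise copy
```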
        p, false, α * coeff, β, backend, allocator
    )
else # Multi-tree block: pack → recoupling matmul → unpack.
    rows, cols = size(U)
Is there ever a case where this is not square? If so, does it make sense to do the trivial permutation on the largest of the two (src vs dest), and the non-trivial permutation on the smallest?
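The pack → matmul → unpack pattern, including the non-square case asked about here, can be sketched with plain arrays. The block shape, the tree counts, and the entries of `U` are made up for illustration, and the sketch assumes rows of `U` index destination trees and columns index source trees.

```julia
using LinearAlgebra

sz = (2, 3)                                  # shape of one block (hypothetical)
n_src, n_dst = 3, 2                          # number of source / destination trees
U = reshape(collect(1.0:6.0), n_dst, n_src)  # non-square stand-in for the recoupling matrix
src_blocks = [fill(Float64(i), sz) for i in 1:n_src]

# 1. Pack: each source block becomes one column of buffer_src.
buffer_src = Matrix{Float64}(undef, prod(sz), n_src)
for (i, blk) in enumerate(src_blocks)
    buffer_src[:, i] .= vec(blk)
end

# 2. Recouple: buffer_dst = buffer_src * U^T.
buffer_dst = Matrix{Float64}(undef, prod(sz), n_dst)
mul!(buffer_dst, buffer_src, transpose(U))

# 3. Unpack: each column of buffer_dst becomes one destination block.
dst_blocks = [reshape(buffer_dst[:, j], sz) for j in 1:n_dst]
```

When `U` is not square the two buffers differ in size (`prod(sz) * n_src` versus `prod(sz) * n_dst`), which is what would make the choice of side for the non-trivial permutation matter.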
        p, false, α, β, backend, allocator
    )
end
@lock buffer_lock TO.tensorfree!(buffer, allocator)
I have forgotten how the allocators work. Is it now fine that we free buffers in an order that is not the exact reverse of the order in which they were allocated?
# buffers have to be created without race condition: err on the side of caution with a lock
buffer_lock = Threads.ReentrantLock()

OhMyThreads.@tasks for src in fusionblocks(tsrc)
There is some code duplication between this generic implementation, and the one for GenericTreeTransformer below (hence my questions there also apply here). But I think the code duplication is unavoidable (and anyway quite limited).
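The locking pattern shared by both implementations can be sketched with Base threading primitives alone. The free list and the `acquire!`/`release!` helpers below are hypothetical stand-ins for an allocator that is not thread-safe; they are not the TensorOperations allocator interface, and `Threads.@spawn` replaces `OhMyThreads.@tasks` to keep the sketch dependency-free.

```julia
# Hypothetical free list standing in for a non-thread-safe allocator.
free_buffers = Vector{Float64}[]

function acquire!(free, n::Int)
    isempty(free) && return Vector{Float64}(undef, n)
    return resize!(pop!(free), n)
end
release!(free, buf) = (push!(free, buf); nothing)

buffer_lock = ReentrantLock()
results = zeros(8)

tasks = map(1:8) do i
    Threads.@spawn begin
        # Only allocation and release are serialized; the work itself runs unlocked.
        buf = @lock buffer_lock acquire!(free_buffers, 16)
        fill!(buf, i)
        results[i] = sum(buf)
        @lock buffer_lock release!(free_buffers, buf)
    end
end
foreach(wait, tasks)
```

Each task writes to its own slot of `results`, so only the shared free list needs protection.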
    return maximum(transformer.data; init = 0) do (basistransform, structures_dst, _)
        return prod(structures_dst[1]) * size(basistransform, 1)
    end
end
Is this function used anywhere? I couldn't find a single calling instance. The add_transform_kernel! computes the buffer size manually.
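For readers unfamiliar with the idiom in this hunk, here is a standalone sketch of the same reduction. The data layout is an assumption made for illustration: each entry is a tuple whose second element starts with the block size, mimicking `structures_dst[1]`. The name `buffersize` is hypothetical.

```julia
# Each entry mimics (basistransform, structures_dst, structures_src); sizes are made up.
data = [
    (zeros(2, 3), ((2, 2), (1, 2), 0), nothing),
    (zeros(4, 4), ((3, 1, 2), (1, 3, 3), 0), nothing),
    (zeros(1, 1), ((5, 5), (1, 5), 0), nothing),
]

function buffersize(data)
    # `init = 0` makes the reduction well defined for an empty collection.
    return maximum(data; init = 0) do (basistransform, structures_dst, _)
        return prod(structures_dst[1]) * size(basistransform, 1)
    end
end
```

The largest entry determines the buffer: block size times the number of rows of the basis transform.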
Jutho left a comment:
Ok, this looks great. I've left a few questions, but mostly just to get better understanding.
Summary
This PR overhauls the index manipulation API in src/tensors/indexmanipulations.jl to match TensorOperations dispatch conventions, reduces code duplication in the implementation, and adds a dedicated documentation page.

The goal was a bunch of code simplification (overall number of lines reduced, even though I added some docs 🎉).

API changes
- permute!, braid!, transpose!, and repartition! now directly accept α, β, backend, and allocator as optional arguments (with defaults One(), Zero(), DefaultBackend(), DefaultAllocator()), following the TensorOperations dispatch pattern. The old add_permute!, add_braid!, and add_transpose! are deprecated and forward to the new functions.
- allocator support: previously, the index manipulation functions did not support a custom allocator at all. It is now a positional argument in both the public and internal interfaces, consistent with TensorOperations convention.
- permute, braid, transpose, and repartition gain backend as a new keyword argument alongside the now-supported allocator keyword.

Implementation changes
- braid!, eliminating duplicate codepaths.
- braid! method added for AdjointTensorMap.
- add_transform! kernels for TensorMap refactored to operate on the raw data vector rather than the full TensorMap. Because the data vector has no symmetry type, this avoids recompilation for every TensorMap type combination, improving compilation time.
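The role of the α and β arguments described under "API changes" can be sketched on plain arrays, assuming the usual TensorOperations convention dst = β * dst + α * permute(src). The helper `scaled_permute!` is hypothetical and is not the TensorKit API; `true` and `false` stand in for the `One()` and `Zero()` defaults.

```julia
# Hypothetical helper illustrating the α/β convention on plain arrays.
function scaled_permute!(dst::AbstractArray, src::AbstractArray, perm, α = true, β = false)
    # β = false acts as a strong zero, so the previous contents of dst are ignored.
    dst .= β .* dst .+ α .* PermutedDimsArray(src, perm)
    return dst
end

src = reshape(collect(1.0:6.0), 2, 3)

dst = ones(3, 2)
scaled_permute!(dst, src, (2, 1), 2.0, 3.0)  # dst = 3 * dst + 2 * transpose(src)

plain = scaled_permute!(similar(dst), src, (2, 1))  # defaults: a pure permutation
```

With the default arguments the call reduces to a pure permutation, which is the behavior the deprecated add_* wrappers had to be called separately to scale.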