Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 59 additions & 18 deletions src/bispectrum/dn_on_dn.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,37 @@ def __init__(self, n: int, selective: bool = True) -> None:
self.register_buffer('_idft_cos', cos_table)
self.register_buffer('_idft_sin', sin_table)

# --- Precompute vectorized Fplus gather indices ----------------------
# fhat shape at forward time: (batch, 2, 2, n2d+1)
# Flatten last 3 dims → (batch, 4*(n2d+1)), then pad with one zero.
# _fplus_src[m, i, j] = flat index into padded fhat, or pad_idx for zeros.
fhat_flat_size = 4 * (n2d + 1)
pad_idx = fhat_flat_size # points to the appended zero

fplus_src = torch.full((max(n3, 1), 4, 4), pad_idx, dtype=torch.long)

for m, decomp in enumerate(self._decompositions):
for block in decomp:
if block.block_type == '2d':
r0, r1 = block.rows
k = block.label
assert isinstance(k, int)
# fhat[:, row, col, k] → flat idx = row*2*(n2d+1) + col*(n2d+1) + k
fplus_src[m, r0, r0] = 0 * 2 * (n2d + 1) + 0 * (n2d + 1) + k
fplus_src[m, r0, r1] = 0 * 2 * (n2d + 1) + 1 * (n2d + 1) + k
fplus_src[m, r1, r0] = 1 * 2 * (n2d + 1) + 0 * (n2d + 1) + k
fplus_src[m, r1, r1] = 1 * 2 * (n2d + 1) + 1 * (n2d + 1) + k
else:
r = block.rows[0]
lbl = block.label
assert isinstance(lbl, str)
row, col = {'rho0': (0, 0), 'rho01': (1, 0), 'rho02': (0, 1), 'rho03': (1, 1)}[
lbl
]
fplus_src[m, r, r] = row * 2 * (n2d + 1) + col * (n2d + 1) + 0

self.register_buffer('_fplus_src', fplus_src)

# --- index map ------------------------------------------------------
idx_map: list[tuple[int, ...]] = [(0, 0)]
for r in range(2):
Expand Down Expand Up @@ -421,31 +452,41 @@ def forward(self, f: torch.Tensor) -> torch.Tensor:

fhat = self._group_dft(f)

parts: list[torch.Tensor] = []

# beta_{rho0, rho0} = F(rho0)^3
F_rho0 = fhat[:, 0, 0, 0]
parts.append((F_rho0**3).unsqueeze(-1))
beta_00 = (F_rho0**3).unsqueeze(-1)

# beta_{rho0, rho1} = F(rho0) * F(rho1)^T @ F(rho1) (2x2)
F_rho1 = fhat[:, :, :, 1]
beta_01 = F_rho0[:, None, None] * torch.bmm(F_rho1.transpose(-1, -2), F_rho1)
parts.append(beta_01.reshape(batch, 4))

# beta_{rho1, rho_k} for k = 1..n3 (4x4 each)
for m in range(n3):
k = m + 1
F_k = fhat[:, :, :, k]
fh_kron = _batched_kron_2x2(F_rho1, F_k)

Fplus = self._build_fplus(fhat, m)
C = self._cg_matrices[m].to(dtype)

# beta = C (oplus F^T) C^T (F1 x Fk)
beta_1k = C @ Fplus.transpose(-1, -2) @ C.T @ fh_kron
parts.append(beta_1k.reshape(batch, 16))

return torch.cat(parts, dim=-1)
if n3 == 0:
return torch.cat([beta_00, beta_01.reshape(batch, 4)], dim=-1)

# Batched kron: all n3 products at once
# F_rho1: (batch, 2, 2), F_all_k: (batch, 2, 2, n3)
F_all_k = fhat[:, :, :, 1 : n3 + 1]
all_kron = torch.einsum('bij,bklm->bmikjl', F_rho1, F_all_k).reshape(batch, n3, 4, 4)

# Vectorized Fplus via precomputed gather indices
fhat_flat = fhat.reshape(batch, -1) # (batch, 4*(n2d+1))
fhat_padded = torch.cat(
[fhat_flat, torch.zeros(batch, 1, device=f.device, dtype=dtype)], dim=-1
)
src = self._fplus_src[:n3].reshape(-1) # (n3*16,)
all_fplus = fhat_padded[:, src].reshape(batch, n3, 4, 4)

# Batched matmul: C @ Fplus^T @ C^T @ kron for all k at once
C_all = self._cg_matrices[:n3].unsqueeze(0).to(dtype) # (1, n3, 4, 4)
C_all_T = C_all.transpose(-1, -2)
step1 = torch.matmul(C_all, all_fplus.transpose(-1, -2))
step2 = torch.matmul(step1, C_all_T)
all_beta_1k = torch.matmul(step2, all_kron) # (batch, n3, 4, 4)

return torch.cat(
[beta_00, beta_01.reshape(batch, 4), all_beta_1k.reshape(batch, n3 * 16)],
dim=-1,
)

def invert(self, beta: torch.Tensor, **kwargs: object) -> torch.Tensor:
"""Recover a signal from its selective bispectrum.
Expand Down
22 changes: 16 additions & 6 deletions src/bispectrum/so3_on_s2.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,34 @@ def _proved_linear_bootstrap_block(l_target: int) -> list[tuple[int, int, int]]:
For each target degree ell >= 8, use:

- X0_a = (a, ell, ell-a), 1 <= a <= ell-1
- X1_a = (a, ell, ell-a+1), 2 <= a <= ell-1
- C_a = (a, ell-a, ell), 1 <= a <= 4

This gives exactly
(ell - 1) + (ell - 2) + 4 = 2*ell + 1
linear equations in F_ell.
- X1_a = (a, ell, ell-a+1), 2 <= a <= ell-1, EXCLUDING a = (ell+1)//2
when ell is odd. The excluded triple is (r, 2*r-1, r) at ell = 2r-1,
which has a repeated index and parity 4r-1 (odd), so it is identically
zero on real signals (Proposition prop:app-odd-vanishing in the paper).
- C_a = (a, ell-a, ell), 1 <= a <= 4 (chain rows).
- Z_ell = {(2, ell-1, ell)} for odd ell, empty for even ell. This single
shifted-chain row compensates for the dropped second-family entry,
preserving the budget |T_ell| = 2*ell + 1. It has all-distinct
indices (admissible, not contained in C since C only covers
pairs (a, ell-a) with a in {1,2,3,4}).
"""
block: list[tuple[int, int, int]] = []

for a in range(1, l_target):
block.append((a, l_target, l_target - a))

skip_a = (l_target + 1) // 2 if l_target % 2 == 1 else None
for a in range(2, l_target):
if a == skip_a:
continue
block.append((a, l_target, l_target - a + 1))

for a in range(1, 5):
block.append((a, l_target - a, l_target))

if l_target % 2 == 1:
block.append((2, l_target - 1, l_target))

return block


Expand Down
Loading
Loading