Skip to content

Commit 0f888c4

Browse files
committed
compute_fejer_weights method. Still need to optimize (very slow)
1 parent 921bec2 commit 0f888c4

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

src/aspire/abinitio/commonline_nug.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
**kwargs,
4242
):
4343
"""
44-
Initialize object for estimating 3D orientations for molecules with C3 and C4 symmetry.
44+
Initialize object for estimating 3D orientations for symmetric molecules.
4545
4646
:param src: The source object of 2D denoised or class-averaged images with metadata
4747
:param symmetry: A string, ie. 'C3', indicating the symmetry type.
@@ -763,6 +763,9 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid):
763763
loadmat("data/Fejer/Ngrid=%i" % Ngrid + "/k=%i" % k + ".mat")["W1"]
764764
for k in range(1, Lmax + 1)
765765
]
766+
767+
W0k, W1k = self.compute_fejer_weights()
768+
766769
# AI_mat=np.zeros((Ngrid,D0+D1))
767770
# for p in range(Ngrid):
768771
# w0=np.zeros(D0); w1=np.zeros(D1)
@@ -864,6 +867,47 @@ def ADMM_preprocessing(self, C, Lmax, N, Ngrid):
864867
Sq,
865868
)
866869

870+
def compute_fejer_weights(self):
871+
SO3_grid = loadmat("data/SO3_grid.mat")["SO3"]
872+
Ngrid = SO3_grid.shape[0]
873+
start = 1
874+
875+
TT = []
876+
TTI = []
877+
for ell in range(start, self.Lmax + 1):
878+
T, Tinv = self.complex2real(ell)
879+
TT.append(T)
880+
TTI.append(Tinv)
881+
882+
def permutek_block(Ak, k):
883+
dk = 2 * k + 1
884+
Pk = np.eye(dk)
885+
for m in range(k):
886+
for l in range(k - m):
887+
Pk[(m + 2 * l, m + 2 * l + 1), :] = Pk[
888+
(m + 2 * l + 1, m + 2 * l), :
889+
]
890+
AkP = Pk @ Ak @ Pk.T
891+
return AkP[:k, :k], AkP[k:, k:]
892+
893+
W0 = []
894+
W1 = []
895+
for k in range(start, self.Lmax + 1):
896+
print(k)
897+
W0k = np.zeros((Ngrid, k, k))
898+
W1k = np.zeros((Ngrid, k + 1, k + 1))
899+
900+
TkT = TT[k - start].T
901+
TinvkT = TTI[k - start].T
902+
903+
for p in range(Ngrid):
904+
w = np.real(TkT @ self.WD(k, SO3_grid[p]).conj() @ TinvkT)
905+
W0k[p], W1k[p] = permutek_block(w, k)
906+
907+
W0.append(W0k)
908+
W1.append(W1k)
909+
return W0, W1
910+
867911
#########################
868912
# Euler Estimation Step #
869913
#########################

0 commit comments

Comments
 (0)