Skip to content

Commit d8a12de

Browse files
committed
Fix a bug in the "none" basis normalization mode and revert to L1 normalization
1 parent 7377f7a commit d8a12de

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

torch_harmonics/attention/attention.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
class AttentionS2(nn.Module):
4848
"""
4949
(Global) attention on the 2-sphere.
50+
5051
Parameters
5152
-----------
5253
in_channels: int
@@ -67,6 +68,13 @@ class AttentionS2(nn.Module):
6768
number of dimensions for interior inner product in the attention matrix (corresponds to kdim in MHA in PyTorch)
6869
out_channels: int, optional
6970
number of dimensions for interior inner product in the attention matrix (corresponds to vdim in MHA in PyTorch)
71+
72+
Reference
73+
---------
74+
Bonev, B., Rietmann, M., Paris, A., Carpentieri, A., & Kurth, T. (2025).
75+
"Attention on the Sphere."
76+
Advances in Neural Information Processing Systems (NeurIPS).
77+
https://arxiv.org/abs/2505.11157
7078
"""
7179

7280
def __init__(
@@ -210,6 +218,13 @@ class NeighborhoodAttentionS2(nn.Module):
210218
number of dimensions for interior inner product in the attention matrix (corresponds to vdim in MHA in PyTorch)
211219
optimized_kernel: Optional[bool]
212220
Whether to use the optimized kernel (if available)
221+
222+
Reference
223+
---------
224+
Bonev, B., Rietmann, M., Paris, A., Carpentieri, A., & Kurth, T. (2025).
225+
"Attention on the Sphere."
226+
Advances in Neural Information Processing Systems (NeurIPS).
227+
https://arxiv.org/abs/2505.11157
213228
"""
214229

215230
def __init__(

torch_harmonics/disco/convolution.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,6 @@ def _normalize_convolution_tensor_s2(
9292
If basis_norm_mode is not one of the supported modes.
9393
"""
9494

95-
# exit here if no normalization is needed
96-
if basis_norm_mode == "none":
97-
return psi_vals
98-
9995
# reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
10096
idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)
10197

@@ -131,9 +127,10 @@ def _normalize_convolution_tensor_s2(
131127
# compute the 1-norm
132128
# scale[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
133129
if basis_norm_mode == "modal":
134-
if ik != 0:
135-
bias[ik, ilat] = torch.sum(psi_vals[iidx] * q[iidx])
136-
scale[ik, ilat] = torch.sqrt(torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs().pow(2) * q[iidx]))
130+
# if ik != 0:
131+
# bias[ik, ilat] = torch.sum(psi_vals[iidx] * q[iidx])
132+
# scale[ik, ilat] = torch.sqrt(torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs().pow(2) * q[iidx]))
133+
scale[ik, ilat] = torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs() * q[iidx])
137134
else:
138135
scale[ik, ilat] = torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs() * q[iidx])
139136

0 commit comments

Comments (0)