Skip to content

Commit 494e09c

Browse files
committed
Fix a bug in the "none" normalization mode and revert to L1 normalization
1 parent aab7b04 commit 494e09c

2 files changed

Lines changed: 19 additions & 7 deletions

File tree

torch_harmonics/attention/attention.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
class AttentionS2(nn.Module):
4747
"""
4848
(Global) attention on the 2-sphere.
49+
4950
Parameters
5051
-----------
5152
in_channels: int
@@ -66,6 +67,13 @@ class AttentionS2(nn.Module):
6667
number of dimensions for interior inner product in the attention matrix (corresponds to kdim in MHA in PyTorch)
6768
out_channels: int, optional
6869
number of dimensions for interior inner product in the attention matrix (corresponds to vdim in MHA in PyTorch)
70+
71+
Reference
72+
---------
73+
Bonev, B., Rietmann, M., Paris, A., Carpentieri, A., & Kurth, T. (2025).
74+
"Attention on the Sphere."
75+
Advances in Neural Information Processing Systems (NeurIPS).
76+
https://arxiv.org/abs/2505.11157
6977
"""
7078

7179
def __init__(
@@ -209,6 +217,13 @@ class NeighborhoodAttentionS2(nn.Module):
209217
number of dimensions for interior inner product in the attention matrix (corresponds to vdim in MHA in PyTorch)
210218
optimized_kernel: Optional[bool]
211219
Whether to use the optimized kernel (if available)
220+
221+
Reference
222+
---------
223+
Bonev, B., Rietmann, M., Paris, A., Carpentieri, A., & Kurth, T. (2025).
224+
"Attention on the Sphere."
225+
Advances in Neural Information Processing Systems (NeurIPS).
226+
https://arxiv.org/abs/2505.11157
212227
"""
213228

214229
def __init__(

torch_harmonics/disco/convolution.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,6 @@ def _normalize_convolution_tensor_s2(
8989
If basis_norm_mode is not one of the supported modes.
9090
"""
9191

92-
# exit here if no normalization is needed
93-
if basis_norm_mode == "none":
94-
return psi_vals
95-
9692
# reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
9793
idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)
9894

@@ -128,9 +124,10 @@ def _normalize_convolution_tensor_s2(
128124
# compute the 1-norm
129125
# scale[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
130126
if basis_norm_mode == "modal":
131-
if ik != 0:
132-
bias[ik, ilat] = torch.sum(psi_vals[iidx] * q[iidx])
133-
scale[ik, ilat] = torch.sqrt(torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs().pow(2) * q[iidx]))
127+
# if ik != 0:
128+
# bias[ik, ilat] = torch.sum(psi_vals[iidx] * q[iidx])
129+
# scale[ik, ilat] = torch.sqrt(torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs().pow(2) * q[iidx]))
130+
scale[ik, ilat] = torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs() * q[iidx])
134131
else:
135132
scale[ik, ilat] = torch.sum((psi_vals[iidx] - bias[ik, ilat]).abs() * q[iidx])
136133

0 commit comments

Comments (0)