-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathspaceencoding.py
More file actions
39 lines (32 loc) · 957 Bytes
/
spaceencoding.py
File metadata and controls
39 lines (32 loc) · 957 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch.nn as nn
import torch as th
class SpatialEncoder(nn.Module):
    """
    Spatial encoder that maps shortest-path distances to per-head encodings.

    Each clamped distance indexes a learnable embedding row of width
    ``cfg.n_heads``. Distances are clamped to ``[-1, cfg.max_dist]`` and then
    shifted by one. Any value <= -1 therefore lands on row 0, which is the
    embedding's ``padding_idx``: it starts as a zero vector and receives no
    gradient. (Presumably -1 marks unreachable node pairs; confirm with the
    caller.)

    Args:
        cfg: Configuration object. Attributes read:
            ``max_num_nodes`` (int) and ``max_dist`` (int): used to size the
            embedding table. ``max_dist`` is also the upper clamp applied to
            distances in ``forward``.
            ``n_heads`` (int): embedding width, one value per attention head.
            ``device``: device on which the embedding table is allocated.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # forward() produces indices in [0, max_dist + 1], so the table needs at
        # least max_dist + 2 rows. The table was previously sized from
        # max_num_nodes + 2. Taking the max keeps that exact shape (and hence
        # existing checkpoints) whenever it was already large enough. It also
        # prevents an out-of-range lookup when max_dist > max_num_nodes.
        num_embeddings = max(cfg.max_num_nodes, cfg.max_dist) + 2
        self.embedding_table = nn.Embedding(
            num_embeddings, cfg.n_heads, padding_idx=0, device=cfg.device
        )

    def forward(self, dist):
        """
        Forward pass for the spatial encoder.

        Args:
            dist (Tensor): Integer (int/long) tensor of shortest-path
                distances, of any shape.
        Returns:
            Tensor: Spatial encoding of shape ``dist.shape + (cfg.n_heads,)``.
        """
        # Cap distances at max_dist and floor them at -1, then shift by one so
        # that -1 maps onto the padding row 0 and distance d maps to row d + 1.
        index = th.clamp(dist, min=-1, max=self.cfg.max_dist) + 1
        return self.embedding_table(index)