model.py
import math
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool


class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.seq_len = seq_len
self.dropout = nn.Dropout(dropout)
# Create a matrix of shape (seq_len, d_model)
pe = torch.zeros(seq_len, d_model)
# Create a vector of shape (seq_len, 1)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
# Compute the positional encodings once in log space.
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
# Apply the sin to even positions
pe[:, 0::2] = torch.sin(position * div_term)
# Apply the cos to odd positions
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, Seq_len, d_model) batch dimension
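        # register_buffer keeps pe in state_dict and moves it with .to(device),
        # without making it a trainable parameter.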
self.register_buffer("pe", pe)

    def forward(self, x):
# x: [batch_size, seq_len, d_model]
        # Add the frozen positional encodings, sliced to the current sequence length
        x = x + self.pe[:, : x.shape[1], :].requires_grad_(False)
return self.dropout(x)


# class LigandGNN(nn.Module):  # GCNConv variant, kept for reference
# def __init__(self, input_dim, hidden_channels):
# super().__init__()
# self.hidden_channels = hidden_channels
#
# self.conv1 = GCNConv(input_dim, hidden_channels)
# self.conv2 = GCNConv(hidden_channels, hidden_channels)
# self.conv3 = GCNConv(hidden_channels, hidden_channels)
# self.dropout = nn.Dropout(0.2)
#
# def forward(self, x, edge_index, batch):
# x = self.conv1(x, edge_index)
# x = x.relu()
# x = self.dropout(x)
#
# x = self.conv2(x, edge_index)
# x = x.relu()
# x = self.conv3(x, edge_index)
# x = self.dropout(x)
#
# # Average the node embeddings to get the molecule-level vector
# x = global_mean_pool(x, batch) # [batch_size, hidden_channels]
# return x


class LigandGNN(nn.Module):
def __init__(self, input_dim, hidden_channels, heads=4, dropout=0.2):
super().__init__()
        # heads=4: use 4 attention heads.
        # concat=False: average the heads instead of concatenating them,
        # so the output dimension stays hidden_channels.
self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
self.conv2 = GATConv(
hidden_channels, hidden_channels, heads=heads, concat=False
)
self.conv3 = GATConv(
hidden_channels, hidden_channels, heads=heads, concat=False
)
self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
x = self.conv1(x, edge_index)
x = x.relu()
x = self.dropout(x)
x = self.conv2(x, edge_index)
x = x.relu()
x = self.dropout(x)
x = self.conv3(x, edge_index)
# Global Mean Pooling
x = global_mean_pool(x, batch)
return x


class ProteinTransformer(nn.Module):
def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=h, batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
self.fc = nn.Linear(d_model, output_dim)

    def forward(self, x):
# x: [batch_size, seq_len]
padding_mask = x == 0 # mask for PAD tokens
x = self.embedding(x) * math.sqrt(self.d_model)
x = self.pos_encoder(x)
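        # src_key_padding_mask: True marks PAD positions that attention should ignore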
x = self.transformer(x, src_key_padding_mask=padding_mask)
        # Masked mean pooling: zero out PAD positions, then average over real tokens only
        mask = (~padding_mask).float().unsqueeze(-1)
        x = x * mask
        sum_x = x.sum(dim=1)
        token_counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against all-PAD rows
        x = sum_x / token_counts
x = self.fc(x)
return x


class BindingAffinityModel(nn.Module):
def __init__(
self, num_node_features, hidden_channels=128, gat_heads=4, dropout=0.2
):
super().__init__()
# Tower 1 - Ligand GNN
self.ligand_gnn = LigandGNN(
input_dim=num_node_features,
hidden_channels=hidden_channels,
heads=gat_heads,
dropout=dropout,
)
# Tower 2 - Protein Transformer
self.protein_transformer = ProteinTransformer(
            vocab_size=26,  # token id 0 is PAD (see the x == 0 mask in ProteinTransformer)
d_model=hidden_channels,
output_dim=hidden_channels,
dropout=dropout,
)
self.head = nn.Sequential(
nn.Linear(hidden_channels * 2, hidden_channels),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_channels, 1),
)

    def forward(self, x, edge_index, batch, protein_seq):
ligand_vec = self.ligand_gnn(x, edge_index, batch)
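        # protein_seq arrives flattened; recover the batch size from the PyG
        # batch vector and reshape to [batch_size, seq_len] (assumes fixed-length,
        # padded sequences per sample)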
batch_size = batch.max().item() + 1
protein_seq = protein_seq.view(batch_size, -1)
protein_vec = self.protein_transformer(protein_seq)
combined = torch.cat([ligand_vec, protein_vec], dim=1)
return self.head(combined)
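

if __name__ == "__main__":
    # Minimal smoke test: a sketch of how the two towers are wired together,
    # not part of the original training code. num_node_features=9 and
    # seq_len=64 below are illustrative assumptions, not values fixed by this file.
    from torch_geometric.data import Batch, Data

    num_node_features = 9
    seq_len = 64
    model = BindingAffinityModel(num_node_features=num_node_features)

    # Two tiny random "ligands": 5 and 3 atoms with a few directed edges.
    g1 = Data(
        x=torch.randn(5, num_node_features),
        edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
    )
    g2 = Data(
        x=torch.randn(3, num_node_features),
        edge_index=torch.tensor([[0, 1], [1, 2]]),
    )
    batch = Batch.from_data_list([g1, g2])

    # Protein sequences are expected flattened and padded with id 0 to a fixed
    # length, with token ids below vocab_size (26 in BindingAffinityModel).
    protein_seq = torch.randint(1, 26, (2 * seq_len,))

    out = model(batch.x, batch.edge_index, batch.batch, protein_seq)
    print(out.shape)  # torch.Size([2, 1]): one affinity score per complex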