from collections import defaultdict

import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics.classification import (BinaryAccuracy, BinaryAveragePrecision, BinaryAUROC,
                                         BinaryMatthewsCorrCoef, BinaryPrecision, BinaryRecall,
                                         BinaryF1Score, BinaryConfusionMatrix)
from transformers import EsmTokenizer


class CNNBlock(nn.Module):
    """
    CNN models were trained with the following parameters:
    number of filters=128, kernel size=37,
    pooling size=4, dropout rate=0.2, LSTM output size=50, batch size=128.
    """

    def __init__(self,
                 model_config: dict = None,
                 **kwargs):
        super().__init__()
        # 'same'-length convolution: padding of kernel_size // 2 keeps seq_len unchanged
        self.cnn = nn.Conv1d(model_config['num_filters'], model_config['num_filters'],
                             model_config['kernel_size'], stride=1,
                             padding=model_config['kernel_size'] // 2)
        self.activation = nn.ReLU() if model_config['activation'] == 'relu' else nn.Identity()
        self.dropout = nn.Dropout(model_config['dropout'])
        self.max_pool = nn.MaxPool1d(model_config['max_pool_kernel_size'],
                                     stride=model_config['max_pool_kernel_size'])

    def forward(self, x):
        # x needs to be of shape (batch_size, in_channels, seq_len)
        x = self.cnn(x)  # (batch_size, out_channels, seq_len)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.max_pool(x)
        x = self.dropout(x)  # dropout is applied both after activation and after pooling
        return x
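
# Shape sketch for CNNBlock (illustrative numbers taken from the docstring above:
# num_filters=128, kernel_size=37 with padding 18, max_pool_kernel_size=4): an input
# of shape (batch, 128, 1000) keeps its length through the padded convolution
# (1000 + 2*18 - 37 + 1 = 1000) and is pooled down to (batch, 128, 250), so each
# block divides the sequence length by max_pool_kernel_size.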


class CNNLSTMPredictor(pl.LightningModule):
    def __init__(self,
                 model_config: dict = None,
                 train_config: dict = None,
                 data_config: dict = None,
                 comment: str = '',
                 batch_size: int = 2,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.batch_size = batch_size
        self.tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t12_35M_UR50D')
        self.embedding = nn.Embedding(num_embeddings=self.tokenizer.vocab_size,
                                      embedding_dim=model_config['num_filters'],
                                      padding_idx=self.tokenizer.pad_token_id)
        # Build a fresh CNNBlock per layer; multiplying a single instance
        # ([CNNBlock(model_config)] * num_blocks) would reuse one module and
        # share its weights across all blocks.
        self.cnn = nn.Sequential(*[CNNBlock(model_config) for _ in range(model_config['num_blocks'])])
        self.lstm = nn.LSTM(input_size=model_config['num_filters'],
                            hidden_size=model_config['lstm_hidden_size'],
                            batch_first=True)
        self.dropout = nn.Dropout(model_config['dropout'])
        self.linear = nn.Linear(model_config['lstm_hidden_size'], 1)
        self.sigmoid = nn.Sigmoid()
        self.model_config = model_config
        self.train_config = train_config
        self.data_config = data_config
        self.loss = nn.BCEWithLogitsLoss(
            pos_weight=torch.FloatTensor([train_config['pos_weight']])
            if 'pos_weight' in train_config and train_config['pos_weight'] is not None
            else None)
        self.outputs = {'train': defaultdict(list), 'val': defaultdict(list), 'test': defaultdict(list)}
        self.evaluation = nn.ModuleDict({'accuracy': BinaryAccuracy(), 'precision': BinaryPrecision(),
                                         'recall': BinaryRecall(), 'f1': BinaryF1Score(), 'auc': BinaryAUROC(),
                                         'aps': BinaryAveragePrecision(), 'mcc': BinaryMatthewsCorrCoef()})
        self.confusion_matrix = BinaryConfusionMatrix()
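
    # Note on pos_weight (standard BCEWithLogitsLoss behavior, not specific to this
    # model): the positive-class term of the loss is multiplied by pos_weight, so
    # e.g. pos_weight=3.0 counts each positive example three times as heavily, a
    # common remedy for class imbalance in binary classification.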

    def forward(self, x_original, attention_mask=None, output_weights=False):
        # attention_mask and output_weights are accepted for interface
        # compatibility but are not used in this baseline
        x = self.embedding(x_original)  # (batch_size, seq_len, in_channels)
        x = self.cnn(x.permute(0, 2, 1))  # (batch_size, out_channels, seq_len)
        x = x.permute(0, 2, 1)  # (batch_size, seq_len, out_channels)
        # out: (batch_size, seq_len, hidden_size); h, c: (num_layers, batch_size, hidden_size)
        out, (h, c) = self.lstm(x)
        x = h[-1]  # (batch_size, hidden_size), hidden state of the last LSTM layer
        x = self.dropout(x)  # (batch_size, hidden_size)
        x = self.linear(x).squeeze(1)  # (batch_size,)
        return x

    def predict(self, x, attention_mask=None):
        return self.sigmoid(self(x))

    def set_train_dataset(self, train_dataset):
        self.train_dataset = train_dataset

    def train_dataloader(self):
        if self.train_dataset is None:
            raise ValueError('train_dataset is None, please first set it using set_train_dataset()')
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def set_val_dataset(self, val_dataset):
        self.val_dataset = val_dataset

    def val_dataloader(self):
        if self.val_dataset is None:
            raise ValueError('val_dataset is None, please first set it using set_val_dataset()')
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def set_test_dataset(self, test_dataset):
        self.test_dataset = test_dataset

    def test_dataloader(self):
        if self.test_dataset is None:
            raise ValueError('test_dataset is None, please first set it using set_test_dataset()')
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
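
    # Usage note: the datasets passed to the setters above are expected to yield
    # (input_ids, attention_mask, label) triples, matching the batch unpacking in
    # basic_step below.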

    def on_train_start(self) -> None:
        self.logger.log_hyperparams({'num_params': sum(p.numel() for p in self.parameters()),
                                     'batch_size': self.batch_size})

    def basic_step(self, batch, batch_idx, mode):
        seq, attention_mask, y = batch
        y_hat_logits = self(seq, attention_mask)
        loss = self.loss(y_hat_logits, y)
        # log the loss now; cache predictions for epoch-level metrics
        y_hat = self.sigmoid(y_hat_logits)
        self.log(f'{mode}_loss', loss, on_step=False, on_epoch=True, logger=True)
        self.outputs[mode]['y_hat'].append(y_hat.detach())
        self.outputs[mode]['y'].append(y.detach())
        return loss

    def basic_epoch_end(self, mode):
        outputs = self.outputs[mode]
        y_hat = torch.cat(outputs['y_hat'])
        y = torch.cat(outputs['y'])
        for name, metric in self.evaluation.items():
            self.log(f'{mode}_{name}', metric(y_hat, y.int()), on_step=False, on_epoch=True, logger=True,
                     prog_bar=(mode == 'val'))
        self.outputs[mode] = defaultdict(list)  # reset outputs to free memory

    def training_step(self, batch, batch_idx):
        return self.basic_step(batch, batch_idx, 'train')

    def on_train_epoch_end(self):
        self.basic_epoch_end('train')

    def validation_step(self, batch, batch_idx):
        return self.basic_step(batch, batch_idx, 'val')

    def on_validation_epoch_end(self):
        self.basic_epoch_end('val')

    def test_step(self, batch, batch_idx):
        return self.basic_step(batch, batch_idx, 'test')

    def on_test_epoch_end(self):
        self.basic_epoch_end('test')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.train_config['lr'])
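

# Minimal smoke-test sketch. The config values below are illustrative guesses that
# mirror the hyperparameters reported in the CNNBlock docstring, not the
# repository's actual training configs; adjust them to your setup.
if __name__ == '__main__':
    model_config = {'num_filters': 128, 'kernel_size': 37, 'max_pool_kernel_size': 4,
                    'dropout': 0.2, 'lstm_hidden_size': 50, 'num_blocks': 1,
                    'activation': 'relu'}
    train_config = {'lr': 1e-3, 'pos_weight': None}
    model = CNNLSTMPredictor(model_config=model_config, train_config=train_config,
                             data_config={}, batch_size=2)
    # two short dummy protein sequences, padded to a common length
    tokens = model.tokenizer(['MKTAYIAKQR', 'MKKLVLSL'], return_tensors='pt', padding=True)
    probs = model.predict(tokens['input_ids'])
    print(probs.shape)  # expected: torch.Size([2])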