from utils import *
import torch
import torch.nn as nn
from collections import defaultdict
from tqdm import trange
from transformers import AutoModelForMaskedLM, AutoTokenizer, EsmForMaskedLM, T5Tokenizer, T5ForConditionalGeneration
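# Short aliases for supported Hugging Face checkpoints; any other model id can be
# passed to PLMWrapper directly and is loaded on a best-effort basis.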
MODELS = {
'prottrans-half': 'Rostlab/prot_t5_xl_half_uniref50-enc',
'prottrans': 'Rostlab/prot_t5_xl_uniref50',
'protbert': 'Rostlab/prot_bert',
'esm2-35M': 'facebook/esm2_t12_35M_UR50D',
'esm2-150M': 'facebook/esm2_t30_150M_UR50D',
'esm2-650M': 'facebook/esm2_t33_650M_UR50D'
}
class PLMWrapper(nn.Module):
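    """Uniform wrapper around protein language models (ProtTrans T5, ProtBERT, ESM-2).

    Provides tokenization, per-residue embeddings and logits (with tiling for long
    sequences), and greedy decoding of hidden states back to sequences.
    """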
def __init__(self, plm, **kwargs):
        super().__init__()
self.model_name = plm
if plm in MODELS:
if 'prottrans' in plm:
self.plm = T5ForConditionalGeneration.from_pretrained(MODELS[plm])
self.tokenizer = T5Tokenizer.from_pretrained(MODELS[plm])
elif plm == 'protbert':
self.plm = AutoModelForMaskedLM.from_pretrained(MODELS[plm])
self.tokenizer = AutoTokenizer.from_pretrained(MODELS[plm])
elif 'esm2' in plm:
self.plm = EsmForMaskedLM.from_pretrained(MODELS[plm])
self.tokenizer = AutoTokenizer.from_pretrained(MODELS[plm])
else:
raise ValueError(f'PLM model {plm} not supported.')
else:
try:
if 'prot_t5' in plm:
self.plm = T5ForConditionalGeneration.from_pretrained(plm)
self.tokenizer = T5Tokenizer.from_pretrained(plm)
elif "esm2" in plm:
self.plm = EsmForMaskedLM.from_pretrained(plm)
self.tokenizer = AutoTokenizer.from_pretrained(plm)
else:
self.plm = AutoModelForMaskedLM.from_pretrained(plm)
self.tokenizer = AutoTokenizer.from_pretrained(plm)
            except Exception as e:
                raise ValueError(f'PLM model {plm} not supported.') from e
self.hidden_size = self.plm.config.hidden_size
self.is_t5 = getattr(self.plm.config, "model_type", "") == "t5"
self.plm.gradient_checkpointing_enable()
if kwargs.get('freeze_backbone', True):
self.plm.eval()
freeze_module(self.plm)
# Take input sequence, return input ids and attention_mask
def tokenize(self, input_seqs):
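        """Return (input_ids, attention_mask) on the model's device.

        Residues are space-separated before tokenization, and '!' is accepted as a
        stand-in for the tokenizer's pad token ('X' if the tokenizer has none).
        """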
input_seqs = [' '.join(s) for s in input_seqs]
# Allow using '!' as a stand-in for the tokenizer's pad token.
pad_tok = self.tokenizer.pad_token
if pad_tok is None:
# Fallback: treat '!' as a normal residue (X) when no pad token exists.
pad_tok = "X"
input_seqs = [s.replace('!', pad_tok) for s in input_seqs]
tokenized_seqs = self.tokenizer(
input_seqs,
add_special_tokens=True,
padding="longest",
return_tensors="pt",
)
input_ids = tokenized_seqs["input_ids"]
attn_mask = tokenized_seqs["attention_mask"]
return input_ids.to(self.plm.device), attn_mask.to(self.plm.device)
def decode(self, hidden_state):
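        """Greedily decode per-residue hidden states back to amino-acid strings.

        Assumes the backbone exposes ``lm_head`` (true for the T5 and ESM models
        used here; BERT-style masked-LM heads are named differently).
        """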
        # T5 rescales hidden states by model_dim ** -0.5 before its tied LM head; ESM does not.
        if self.is_t5:
            hidden_state = hidden_state * (self.plm.model_dim ** -0.5)
        logits = self.plm.lm_head(hidden_state)
        seqs = self.tokenizer.batch_decode(logits.argmax(dim=-1), skip_special_tokens=True)
        return [s.replace(' ', '') for s in seqs]
def forward_small_batch(self, input_ids, attn_mask, **kwargs):
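        """Single forward pass over a (small) batch; returns (embeddings, logits) on CPU.

        ``num_hidden_states`` selects which hidden states to concatenate along the
        feature dimension: an explicit list of layer indices, or an int meaning
        "the last N layers".
        """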
outputs = self.plm(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, output_hidden_states=True, return_dict=True)
hs = outputs.encoder_hidden_states if self.is_t5 else outputs.hidden_states
        retrieve_hs = kwargs.get('num_hidden_states', 1)
        # ProtTrans (T5) has no CLS token; for ESM-style models drop the leading CLS
        # position so embeddings align with the raw sequence.
        first_residue = 0 if self.is_t5 else 1
        if isinstance(retrieve_hs, list):
            # Explicit list of hidden-state indices to concatenate along the feature dim.
            embeddings = [hs[i][:, first_residue:].to('cpu') for i in retrieve_hs]
        else:
            # Otherwise concatenate the last `retrieve_hs` hidden states.
            assert isinstance(retrieve_hs, int), 'num_hidden_states must be either a list or an int'
            embeddings = [hs[i][:, first_residue:].to('cpu') for i in range(len(hs) - retrieve_hs, len(hs))]
        embeddings = torch.cat(embeddings, dim=-1)
        # Likewise drop the leading CLS position from the logits for ESM-style models.
        logits = outputs.logits[:, first_residue:].to('cpu')
del outputs
torch.cuda.empty_cache()
return embeddings, logits
def forward(self, seqs, **kwargs):
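        """Embed sequences, tiling any sequence of length >= ``maxlen``.

        Tiles start every ``overlap`` residues, so consecutive tiles share
        ``maxlen - overlap`` positions; overlapping tile outputs are averaged.
        Returns lists of per-sequence embeddings and logits on CPU.
        """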
batch = kwargs.get('batch', None)
maxlen = kwargs.get('maxlen', 1022)
overlap = kwargs.get('overlap', 750)
tiles = []
sid_to_tid = defaultdict(list)
for i, s in enumerate(seqs):
if len(s) < maxlen:
tiles.append(s)
sid_to_tid[i].append(len(tiles) - 1)
else:
start = 0
while start < len(s):
tiles.append(s[start: min(len(s), start + maxlen)])
sid_to_tid[i].append(len(tiles) - 1)
start += overlap
input_ids, attn_mask = self.tokenize(tiles) if kwargs.get('tokenize', True) else tiles
if batch is None:
enc_embeddings, logits = self.forward_small_batch(input_ids, attn_mask, **kwargs)
else:
enc_embeddings, logits = [], []
for j in trange(0, len(tiles), batch):
batch_end = min(j + batch, len(tiles))
                ij, aj = input_ids[j: batch_end], attn_mask[j: batch_end]
ej, lj = self.forward_small_batch(ij, aj, **kwargs)
enc_embeddings.append(ej)
logits.append(lj)
enc_embeddings = torch.cat(enc_embeddings, dim=0)
logits = torch.cat(logits, dim=0)
        # Merge overlapping tiles per sequence (averaging where tiles overlap) and strip EOS/padding.
final_embeddings = []
final_logits = []
for i in range(len(seqs)):
if len(sid_to_tid[i]) == 1:
tid = sid_to_tid[i][0]
tile_length = len(tiles[tid])
final_embeddings.append(enc_embeddings[tid][:tile_length])
final_logits.append(logits[tid][:tile_length])
else:
device = enc_embeddings.device
logit = torch.zeros(len(seqs[i]), logits.shape[-1], device=device)
embedding = torch.zeros(len(seqs[i]), enc_embeddings.shape[-1], device=device)
num_tiles = torch.zeros(len(seqs[i]), device=device)
for j, tid in enumerate(sid_to_tid[i]):
tile_length = len(tiles[tid])
start_idx = j * overlap
end_idx = start_idx + tile_length
logit[start_idx: end_idx] += logits[tid][:tile_length].to(device)
embedding[start_idx: end_idx] += enc_embeddings[tid][:tile_length].to(device)
num_tiles[start_idx: end_idx] += 1
embedding = embedding / num_tiles[:, None]
                logit = logit / num_tiles[:, None]
final_embeddings.append(embedding.to('cpu'))
final_logits.append(logit.to('cpu'))
del embedding, logit, num_tiles
torch.cuda.empty_cache()
del enc_embeddings, logits
torch.cuda.empty_cache()
return final_embeddings, final_logits
if __name__ == '__main__':
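    # Smoke test: embed a few random sequences and compare their lengths with the
    # shapes of the returned per-residue embeddings.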
from msa_tools import *
set_seed(123)
plm = PLMWrapper('esm2-35M').to('cuda')
seqs = [
random_protein_sequence(length=random.randint(500, 750))
        for _ in range(5)
]
seqs = polypad_paddings(seqs, w=4)
embeddings, logits = plm(seqs)
pprint([len(s) for s in seqs])
pprint([e.shape for e in embeddings])