forked from AI4PDLab/ProtRL
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathseq_gen.py
More file actions
128 lines (98 loc) · 4.71 KB
/
seq_gen.py
File metadata and controls
128 lines (98 loc) · 4.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from transformers import GPT2LMHeadModel, AutoTokenizer, AutoModelForCausalLM
import os
from tqdm import tqdm
import math
import argparse
def remove_characters(sequence, char_list):
    """Strip the conditioning tag and every training-time special token.

    The generated text has the form ``<label><sep><tagged sequence>...``; the
    portion between the first and second ``<sep>`` is the sequence itself.
    """
    tagged_seq = sequence.split('<sep>')[1]
    # Delete each special token (and spaces) so only amino-acid letters remain.
    for token in char_list:
        tagged_seq = tagged_seq.replace(token, '')
    return tagged_seq
def calculatePerplexity(input_ids, model, tokenizer):
    """Return the perplexity of ``input_ids`` under ``model``.

    Perplexity is exp(mean cross-entropy), where the cross-entropy is the
    language-modelling loss the model reports when the labels equal the
    inputs. ``tokenizer`` is accepted for interface compatibility but unused.
    """
    with torch.no_grad():
        # Element 0 of the model output is the LM loss (labels == inputs).
        lm_loss = model(input_ids, labels=input_ids)[0]
    return math.exp(lm_loss)
def calculateloglikelihood(input_ids, model, ref_model, tokenizer):
    """Return the reward ``-(loss_model - loss_ref)`` for a sequence.

    Both losses are the mean-token cross-entropy reported by each model when
    labels equal the inputs, so the reward is positive when the fine-tuned
    ``model`` assigns a higher likelihood to the sequence than ``ref_model``.
    ``tokenizer`` is unused but kept for interface compatibility.
    """
    # Fixed: the previous docstring was copy-pasted from calculatePerplexity
    # ("computes perplexities"), and `logits` was unpacked twice (the second
    # unpack shadowed the first) without ever being used.
    with torch.no_grad():
        loss = model(input_ids, labels=input_ids)[0]
        ref_loss = ref_model(input_ids, labels=input_ids)[0]
    i_reward = -(loss - ref_loss)
    return i_reward
def main(label, model, special_tokens, device, tokenizer, ref_model=None):
    """Generate one batch of sequences conditioned on ``label`` and score them.

    Returns ``{label: [(clean_sequence, perplexity, log_likelihood_reward), ...]}``
    sorted by ascending perplexity (lower is better).

    ``ref_model`` is the frozen reference model used for the log-likelihood
    reward. Fixed: the original body read it from an undeclared module-level
    global; it is now an explicit, backward-compatible keyword parameter that
    falls back to that global when omitted.
    """
    if ref_model is None:
        # Backward-compatible fallback to the module-level `ref_model`
        # the original implementation depended on implicitly.
        ref_model = globals()["ref_model"]

    input_ids = tokenizer.encode(label, return_tensors='pt').to(device)
    outputs = model.generate(
        input_ids,
        top_k=9,  # tbd
        repetition_penalty=1.2,
        max_length=1014,
        eos_token_id=1,
        pad_token_id=0,
        do_sample=True,
        num_return_sequences=20)  # how many fit depends on your GPU; this runs on an A40

    # Keep only sequences that finished within max_length: a completed
    # sequence ends in padding (token id 0); anything else was truncated.
    completed = [seq for seq in outputs if seq[-1] == 0]
    if not completed:
        print("not enough sequences with short lengths!!")

    # Score every surviving sequence: (decoded text, perplexity, reward).
    scored = [(tokenizer.decode(seq),
               calculatePerplexity(seq, model, tokenizer),
               calculateloglikelihood(seq, model, ref_model, tokenizer))
              for seq in completed]
    # Lower perplexity first.  NOTE(review): duplicated sequences are not
    # deduplicated here — same caveat as the original code.
    scored.sort(key=lambda item: item[1])

    sequences = {}
    sequences[label] = [(remove_characters(text, special_tokens), ppl, reward)
                        for text, ppl, reward in scored]
    return sequences
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--iteration_num", type=int)
    parser.add_argument("--label", type=str)
    # BUG FIX: args.model_dir was read below but the argument was never
    # declared, so the script crashed with AttributeError. Default to the
    # base ZymCTRL checkpoint so existing invocations keep working.
    parser.add_argument("--model_dir", type=str, default="AI4PD/ZymCTRL")
    args = parser.parse_args()

    iteration_num = args.iteration_num
    ec_label = args.label
    labels = [ec_label.strip()]

    device = torch.device("cuda")  # Replace with 'cpu' if you don't have a GPU - but it will be slow

    print('Reading pretrained model and tokenizer')
    # Iteration 0 starts from the base model (or --model_dir); later
    # iterations resume from the previous RL round's output directory.
    # (Previously two separate `if iteration_num == 0` branches set
    # model_name, the second silently overriding the first.)
    if iteration_num == 0:
        model_name = args.model_dir
    else:
        model_name = f'./output_iteration{iteration_num}'
    print(f'Model {model_name} has been loaded')

    tokenizer = AutoTokenizer.from_pretrained(args.model_dir)  # change to ZymCTRL location
    model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
    special_tokens = ['<start>', '<end>', '<|endoftext|>', '<pad>', ' ', '<sep>']
    # Frozen reference model for the log-likelihood reward (read as a global
    # by main()).
    ref_model = GPT2LMHeadModel.from_pretrained("AI4PD/ZymCTRL").to(device)

    canonical_amino_acids = set("ACDEFGHIKLMNPQRSTVWY")  # Set of canonical amino acids

    for label in tqdm(labels):
        all_sequences = []
        for batch_idx in range(10):
            sequences = main(label, model, special_tokens, device, tokenizer)
            for value in sequences.values():
                for index, val in enumerate(value):
                    # Discard anything containing non-canonical residues.
                    if all(char in canonical_amino_acids for char in val[0]):
                        all_sequences.append({
                            'label': label,
                            'batch': batch_idx,
                            'index': index,
                            # perplexity; key name 'pepr' kept for compatibility
                            'pepr': float(val[1]),
                            'fasta': f">{label}_{batch_idx}_{index}\t{val[1]}\t{val[2]}\n{val[0]}\n",
                        })
        fasta_content = ''.join(seq['fasta'] for seq in all_sequences)
        output_filename = f"seq_gen_{label}_iteration{iteration_num}.fasta"
        print(fasta_content)
        with open(output_filename, "w") as fn:
            fn.write(fasta_content)