# eval_base_qg_model.py
# Generate candidate questions from (answer, context) pairs with the
# ProphetNet question-generation model and write them back to a CSV.
import argparse
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import ProphetNetForConditionalGeneration, ProphetNetTokenizer
# Hugging Face model id: ProphetNet fine-tuned on SQuAD for question generation.
QG_MODEL = "microsoft/prophetnet-large-uncased-squad-qg"
def main(args):
    """Generate two candidate questions for every (answer, context) row.

    Reads a CSV with ``context``, ``question`` and ``answer`` columns,
    feeds "answer [SEP] context" prompts to the ProphetNet QG model, and
    writes the rows back out with two new columns, ``gen_question1`` and
    ``gen_question2``.

    Args:
        args: argparse namespace with ``input``, ``output``, ``device``,
            ``batch_size`` and ``top_p`` attributes.
    """
    # Fix RNG state so repeated runs produce identical generations.
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch_device = args.device
    qg_tokenizer = ProphetNetTokenizer.from_pretrained(QG_MODEL)
    qg_model = ProphetNetForConditionalGeneration.from_pretrained(QG_MODEL)
    qg_model = qg_model.to(torch_device)
    qg_model.eval()
    # Force string dtype so empty cells become "" rather than NaN floats.
    df = pd.read_csv(args.input, converters={"context": str, "question": str, "answer": str})
    # ProphetNet QG expects "answer [SEP] context" as its input text.
    dataset = [f"{row['answer']} [SEP] {row['context']}" for _, row in df.iterrows()]
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, sampler=None, drop_last=False)
    all_questions = []
    # Inference only: no_grad avoids building autograd graphs, cutting memory use.
    with torch.no_grad():
        for data in tqdm(dataloader, total=len(dataloader)):
            qg_inputs = qg_tokenizer(
                data,
                return_tensors="pt",
                truncation=True,
                max_length=300,
                padding=True,
            ).to(torch_device)
            # NOTE(review): top_p only affects sampling; with num_beams set and
            # do_sample left at its False default, generate() silently ignores
            # it. Pass do_sample=True if nucleus sampling is actually intended.
            question_ids = qg_model.generate(
                input_ids=qg_inputs["input_ids"],
                attention_mask=qg_inputs["attention_mask"],
                max_length=120,
                num_beams=10,
                num_return_sequences=2,
                top_p=args.top_p,
            )
            questions = qg_tokenizer.batch_decode(question_ids, skip_special_tokens=True)
            all_questions.extend(questions)
    # generate() returns 2 sequences per input row, interleaved: even indices
    # hold the best beam, odd indices the second-best.
    df["gen_question1"] = all_questions[::2]
    df["gen_question2"] = all_questions[1::2]
    df.to_csv(args.output, index=False)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate questions from (answer, context) CSV rows with ProphetNet QG."
    )
    # required=True makes a missing path fail fast with a clear argparse error
    # instead of a cryptic pandas read_csv(None) traceback inside main().
    parser.add_argument("--input", type=str, required=True,
                        help="Input CSV file with context/question/answer columns")
    parser.add_argument("--output", type=str, required=True,
                        help="Output CSV file (input columns plus gen_question1/2)")
    parser.add_argument("--device", default="cuda:1", type=str,
                        help="Device to run on, e.g. cuda:0 or cpu")
    parser.add_argument("--batch_size", default=32, type=int,
                        help="Batch size for generation")
    parser.add_argument("--top_p", default=0.95, type=float,
                        help="Nucleus-sampling top_p forwarded to generate()")
    args = parser.parse_args()
    main(args)