-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathgenerate_reward_data.py
More file actions
71 lines (56 loc) · 2.34 KB
/
generate_reward_data.py
File metadata and controls
71 lines (56 loc) · 2.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import pandas as pd
from tqdm import tqdm
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM,BitsAndBytesConfig
# Run generation on GPU when one is available; inputs are moved here in
# get_completion() before model.generate().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): placeholder credential — supply a real Hugging Face token
# (ideally via an environment variable); never commit a live token.
hf_token = "HFTOKEN"
# Instruction-tuned GPT-2 checkpoint used to draft candidate answers.
model_id = "vicgalle/gpt2-open-instruct-v1"
# 8-bit quantization settings.
# NOTE(review): this config is never passed to from_pretrained() below, so
# the model actually loads in float16, not 8-bit — confirm intent.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True
)
def get_completion(query_list, model, tokenizer) -> list:
    """Generate one model answer per query, as a single batched call.

    Each question is wrapped in an instruction-style prompt, the whole
    batch is tokenized with padding/truncation (max 1024 tokens), and the
    decoded outputs (prompt + completion) are returned.

    Args:
        query_list: Iterable of question strings.
        model: Causal LM exposing ``generate``.
        tokenizer: Matching tokenizer; called on the prompt list and used
            for ``batch_decode`` and ``eos_token_id``.

    Returns:
        list: One decoded string per query. (Fixed: the previous ``-> str``
        annotation was wrong — ``batch_decode`` returns a list.)
    """
    prompt_template = """
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Pretend you are a medical expert and answer to the following question - {query}
### Response:
"""
    prompt_list = [prompt_template.format(query=query) for query in query_list]
    encodeds = tokenizer(prompt_list, return_tensors="pt", max_length=1024, padding=True, truncation=True)
    model_inputs = encodeds.to(device)
    # NOTE(review): max_length=250 counts prompt + new tokens, so prompts
    # longer than 250 tokens leave no room to generate — consider
    # max_new_tokens instead; confirm before changing sampling behavior.
    generated_ids = model.generate(**model_inputs, max_length=250, do_sample=True,
                                   pad_token_id=tokenizer.eos_token_id,
                                   num_beams=1, num_return_sequences=1)
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded
# Download/load tokenizer and model weights (network + disk side effects).
# NOTE(review): GPT-2 tokenizers ship without a pad token; padding=True in
# get_completion() may require tokenizer.pad_token = tokenizer.eos_token —
# confirm against a real run.
tokenizer = AutoTokenizer.from_pretrained(model_id, token= hf_token, use_fast=True)
# NOTE(review): quantization_config defined above is NOT passed here, so the
# load_in_8bit setting has no effect — the model loads in float16 as written.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    token= hf_token,
    torch_dtype=torch.float16,
)
def generate_diagnosis_given_text(df):
    """Generate an answer for every row of *df*, in fixed-size batches.

    Questions are read from the ``Question`` column, sent through the
    module-level ``get_completion`` ``batch_size`` rows at a time, and the
    text following the first ``'Response'`` marker of each decoded output
    is collected.

    Args:
        df: DataFrame with a ``Question`` column.

    Returns:
        list: One generated answer string per row, in row order.

    Raises:
        IndexError: If a decoded output contains no ``'Response'`` marker.
    """
    # Batch by position, not by df.index values: the previous iterrows-based
    # batching treated index labels as positions and silently produced wrong
    # (often empty) output whenever the index was not a clean RangeIndex.
    questions = df['Question'].tolist()
    diagnosis_list = []
    progress_bar = tqdm(range(len(questions) // batch_size))
    for start in range(0, len(questions), batch_size):
        batch = questions[start:start + batch_size]
        results = get_completion(batch, model=model, tokenizer=tokenizer)
        # split(..., 1) keeps the whole completion even when the generated
        # text itself contains the word 'Response' again (the old unbounded
        # split truncated the answer at the second occurrence).
        diagnosis_list.extend(result.split('Response', 1)[1] for result in results)
        progress_bar.update()
        # Release GPU cache and Python garbage between batches.
        torch.cuda.empty_cache()
        gc.collect()
    return diagnosis_list
# Number of questions sent to the model per generate() call.
batch_size = 128

if __name__ == '__main__':
    # Expects a 'Question' column; presumably produced by an upstream
    # filtering step — confirm against the file that writes it.
    df = pd.read_csv('filtered_train.csv')
    diagnosis_list = generate_diagnosis_given_text(df)
    # NOTE(review): assumes exactly one generated answer per row; a length
    # mismatch makes this assignment raise.
    df['generated'] = diagnosis_list
    # NOTE(review): index=False is probably wanted here — as written the
    # DataFrame index is saved as an extra unnamed column.
    df.to_csv('reward_dataset.csv')