-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
98 lines (72 loc) · 3.12 KB
/
predict.py
File metadata and controls
98 lines (72 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import time
import warnings
import argparse
warnings.filterwarnings("ignore")
import en_core_web_lg
from datetime import datetime, timedelta
from transformers import pipeline
from constants import base_prompts_list, agent_starter_dialogue
def check_entities_in_conversation(model, transcript):
    """Tell whether the transcript contains a DATE or CARDINAL entity.

    If the transcript begins with ``agent_starter_dialogue`` (imported from
    ``constants``), that prefix is removed before the text is analysed.

    Args:
        model: Callable applied to the transcript text; its result must
            expose ``.ents``, an iterable of objects with a ``label_``
            attribute.
        transcript: Conversation text to inspect.

    Returns:
        True if at least one entity is labelled ``'DATE'`` or ``'CARDINAL'``,
        otherwise False.
    """
    wanted_labels = {'DATE', 'CARDINAL'}

    # Skip the starter prefix so it does not contribute entities.
    text = transcript
    if text.startswith(agent_starter_dialogue):
        text = text[len(agent_starter_dialogue):]

    return any(entity.label_ in wanted_labels for entity in model(text).ents)
def get_predictions(model, model_entity_rec, conversation, date, label_type):
    """Run the text model on a conversation, or short-circuit to 'NA'.

    The prompt is the conversation, a "Today's Date" line, and the base
    prompt registered for ``label_type`` in ``base_prompts_list``.

    Args:
        model: Callable invoked with the assembled prompt string.
        model_entity_rec: Entity model handed to
            ``check_entities_in_conversation``.
        conversation: Conversation text.
        date: Date string inserted after "Today's Date: ".
        label_type: Key into ``base_prompts_list``.

    Returns:
        Whatever ``model(prompt)`` returns when the conversation contains a
        DATE/CARDINAL entity; otherwise ``[{'generated_text': 'NA'}]``.

    Raises:
        KeyError: If ``label_type`` is not a key of ``base_prompts_list``
            (raised before the entity check, when it is a dict).
    """
    # Looked up first so an unknown label_type fails even without entities.
    task_instruction = base_prompts_list[label_type]
    prompt = "".join((conversation, "\nToday's Date: ", date, '\n', task_instruction))

    if not check_entities_in_conversation(model_entity_rec, conversation):
        return [{'generated_text': 'NA'}]
    return model(prompt)
def add_days_to_date(res, date):
    """Turn a predicted day offset into an absolute ``YYYY-MM-DD`` date.

    Args:
        res: List whose first item is a dict with a ``'generated_text'``
            entry holding the predicted number of days. The list is
            modified in place.
        date: Reference date string beginning with ``YYYY-MM-DD``; anything
            after the first comma (e.g. ``", Saturday"``) is ignored.

    Returns:
        The same ``res`` list, with ``res[0]['generated_text']`` replaced by
        the reference date plus the offset, formatted ``YYYY-MM-DD``. The
        value is ``'NA'`` when the offset is zero, is not an integer (this
        includes the ``'NA'`` placeholder), or pushes the date out of the
        range ``datetime`` supports.

    Raises:
        ValueError: If the date part of ``date`` is not valid ``YYYY-MM-DD``.
        IndexError: If ``res`` is empty.
    """
    raw_offset = str(res[0]['generated_text']).strip()

    # Non-numeric text (e.g. the 'NA' placeholder) used to crash in int();
    # treat it like a zero offset, which is reported as 'NA'.
    try:
        num_days = int(raw_offset)
    except ValueError:
        num_days = 0

    if num_days == 0:
        resolved = 'NA'
    else:
        # A malformed reference date is a caller error and still raises.
        base_date = datetime.strptime(date.split(',')[0].strip(), '%Y-%m-%d')
        try:
            resolved = (base_date + timedelta(days=num_days)).strftime('%Y-%m-%d')
        except OverflowError:
            # Offset too large for timedelta or the resulting date.
            resolved = 'NA'

    res[0]['generated_text'] = resolved
    return res
def get_model_params(label_type):
    """Build the hub path of the model for ``label_type``.

    The path has the form
    ``pratt3000/Salient_ai<checkpoint>_<label_type>``, where ``<checkpoint>``
    is the part of the base model id after the slash.

    Args:
        label_type: Suffix appended to the model name.

    Returns:
        The model path string.
    """
    base_checkpoint = "google/flan-t5-large"
    # Keep only the checkpoint name, dropping the "google/" organisation.
    _, _, checkpoint_name = base_checkpoint.partition("/")
    hub_name = "_".join(("Salient_ai" + checkpoint_name, label_type))
    return "pratt3000/" + hub_name
def predict(conversation, date, label_type):
    """Load the models for ``label_type`` and predict for one conversation.

    Both the entity model (``en_core_web_lg``) and the text pipeline are
    loaded on every call. Progress and the elapsed time of the prediction
    step (measured with ``time.time()``) are printed to stdout.

    Args:
        conversation: Conversation text.
        date: Date string passed through to the prompt.
        label_type: Selects the model path and base prompt; for
            ``"days_diff"`` the output is post-processed by
            ``add_days_to_date``.

    Returns:
        The ``'generated_text'`` value of the first prediction.
    """
    hub_path = get_model_params(label_type)
    print(f"Loading model from {hub_path}")
    entity_model = en_core_web_lg.load()
    text_model = pipeline(model=hub_path)

    print("Running Prediction Engine")
    started_at = time.time()
    outputs = get_predictions(text_model, entity_model, conversation, date, label_type=label_type)
    if label_type == "days_diff":
        outputs = add_days_to_date(outputs, date)
    elapsed = round(time.time() - started_at, 3)
    print(f"Time taken (CPU): {elapsed} seconds\n")

    return outputs[0]['generated_text']
if __name__ == "__main__":
    # Command-line entry point: parse the options, run one prediction (or
    # two, in ensemble mode) and print the inputs alongside the result.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--conversation',
        type=str,
        default="Agent: Can you advise when you'll manage to make the payment? \nCustomer: I should be able to do it on Thursday next week.",
    )
    cli.add_argument('--date', type=str, default="2022-01-01, Saturday")
    cli.add_argument('--label_type', type=str, default="days_diff")
    cli.add_argument('--ensemble', action='store_true')
    args = cli.parse_args()

    if not args.ensemble:
        pred = predict(args.conversation, args.date, args.label_type)
    else:
        # Ensemble mode: run both label types and report whether they agree.
        label_pred = predict(args.conversation, args.date, "label")
        days_pred = predict(args.conversation, args.date, "days_diff")
        if label_pred != days_pred:
            pred = label_pred + " or " + days_pred + "(Confidence: LOW)"
        else:
            pred = label_pred + "(Confidence: HIGH)"

    print("DATE: ", args.date)
    print("CONVERSATION: ")
    print(args.conversation)
    print(f"RESULT: {pred}")