-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathinference.py
More file actions
68 lines (60 loc) · 2.59 KB
/
inference.py
File metadata and controls
68 lines (60 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Script to generate inference for a given csv file
from fast_bert.prediction import BertClassificationPredictor
import argparse
import csv
import pandas as pd
import os
from sklearn.metrics import classification_report
from sklearn.preprocessing import MultiLabelBinarizer
from pprint import pprint
# run inference on the csv file provided using the trained model
def run(model, csvs, threshold, evaluation,
        label_path="D:\\UTD\\Assignment\\NLP\\project\\"):
    """Run the trained classifier over a CSV file and save its predictions.

    The input CSV is read with pandas and must contain an ``id`` column, a
    ``text`` column and one column per emotion label listed in ``labels``
    below (a cell equal to 1 marks the label as a gold target for that row).
    Predictions are written to ``model_output.csv`` in the same directory as
    the input file, one line per input row.

    Args:
        model: Path of the directory holding the trained model; passed to
            ``BertClassificationPredictor`` as ``model_path``.
        csvs: Path of the CSV file to run inference on.
        threshold: A label is kept for a row when its predicted score is
            strictly greater than this value.
        evaluation: When truthy, print classification reports comparing the
            thresholded predictions against the gold targets.
        label_path: Directory containing the ``labels.csv`` file used by the
            predictor. Defaults to the location previously hard-coded here.
    """
    labels = ["anger", "anticipation", "disgust", "fear", "joy", "love",
              "optimism", "pessimism", "sadness", "surprise", "trust",
              "neutral"]

    # NOTE(review): the predictor is built with multi_label=False although
    # the scores are thresholded per label below -- confirm this matches how
    # the model was trained.
    predictor = BertClassificationPredictor(
        model_path=model,  # use the argument, not the global CLI namespace
        label_path=label_path,  # location for labels.csv file
        multi_label=False,
        model_type='bert',
        do_lower_case=False)

    data = pd.read_csv(csvs)

    # Keep ids, texts and gold targets as parallel lists so that rows with
    # identical text are not merged and every id stays aligned with its row.
    ids = []
    texts = []
    targets = []
    for _, row in data.iterrows():
        ids.append(row['id'])
        texts.append(row['text'])
        targets.append([label for label in labels if row[label] == 1])

    multiple_predictions = predictor.predict_batch(texts)

    output_path = os.path.join(os.path.dirname(csvs), "model_output.csv")
    outputs = []
    # The context manager guarantees the file is flushed and closed even if
    # writing a row fails.
    with open(output_path, "w", encoding="utf-8", newline="") as out_file:
        csv_writer = csv.writer(out_file)
        csv_writer.writerow(["id", "text", "emotions", "target"])
        for row_id, text, target, scores in zip(
                ids, texts, targets, multiple_predictions):
            # Each entry of ``scores`` is indexed as (label, score).
            predicted = [emotion[0] for emotion in scores
                         if emotion[1] > threshold]
            csv_writer.writerow([row_id, text, predicted, target])
            outputs.append(predicted)

    print("****************\n")
    print("Predictions saved in a file: ", output_path)

    if evaluation:
        print("\n\n Running Model Evaluation\n")
        # A single binarizer with a fixed class list gives y_true and y_pred
        # the same columns in the same order as ``labels``; fitting two
        # separate binarizers can yield mismatched columns.
        binarizer = MultiLabelBinarizer(classes=labels)
        y_true_encoded = binarizer.fit_transform(targets)
        y_pred_encoded = binarizer.transform(outputs)
        pprint(classification_report(y_true_encoded, y_pred_encoded))
        pprint(classification_report(y_true_encoded, y_pred_encoded,
                                     target_names=labels))
def str_to_bool(value):
    """Convert a command-line value such as "true", "False" or "0" to a bool.

    argparse hands option values over as strings, and every non-empty string
    (including "False") is truthy, so boolean options need an explicit
    converter.

    Args:
        value: A bool (returned unchanged) or a string spelling of one.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


if __name__ == "__main__":
    # Command-line entry point: parse the options and run inference.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", default="D:\\UTD\\Assignment\\NLP\\project\\model_output\\3_finetune_e20", help="path to output dir")
    parser.add_argument("--test_csv", default="D:\\UTD\\Assignment\\NLP\\project\\nlp_test.csv")
    parser.add_argument("--threshold", default=0.0017, type=float)
    # type=str_to_bool makes "--evaluation False" actually disable evaluation.
    parser.add_argument("--writeto_file", default=True, type=str_to_bool)
    parser.add_argument("--evaluation", default=True, type=str_to_bool)
    args = parser.parse_args()
    run(args.model_dir, args.test_csv, args.threshold, args.evaluation)