-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path0-split_data.py
More file actions
77 lines (65 loc) · 2.85 KB
/
0-split_data.py
File metadata and controls
77 lines (65 loc) · 2.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import pickle
from pathlib import Path

import numpy as np
from tqdm import tqdm

from dataset import *
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'datasets',
nargs='+',
)
parser.add_argument('-s', '--seed', type=int, default=42)
args = parser.parse_args()
Path("results/").mkdir(exist_ok=True)
for dataset_name in args.datasets:
# ----------------------- 1. Select dataset -----------------------
np.random.seed(args.seed)
# -------- AQuA --------
if dataset_name=="aqua":
# Because the dev and test sets are so small, we randomly split 70/15/15 the train set
dataset = AQuADataset("datasets/aqua")
I = np.random.permutation(len(dataset.train_data))
I_train = I[:int(len(I)*0.7)]
I_valid = I[int(len(I)*0.7):int(len(I)*0.85)]
I_test = I[int(len(I)*0.85):]
# -------- CosmosQA --------
elif dataset_name=="cosmosqa":
# Because the test set doesn't have answers, we randomly split 80/20 the train set and use the 'valid' set as test set
dataset = CosmosQADataset("datasets/cosmosqa")
I = np.random.permutation(len(dataset.train_data))
I_train = I[:int(len(I)*0.8)]
I_valid = I[int(len(I)*0.8):]
I_test = None
# -------- MMLU --------
elif dataset_name=="mmlu":
I_train = None
I_valid = None
I_test = None
# -------- MedMCQA --------
elif dataset_name=="medmcqa":
# Because the dev and test sets are so small, we randomly split 70/15/15 the train set
dataset = MedMCQADataset("datasets/medmcqa")
I = np.random.permutation(len(dataset.train_data))
I_train = I[:int(len(I)*0.7)]
I_valid = I[int(len(I)*0.7):int(len(I)*0.85)]
I_test = I[int(len(I)*0.85):]
# -------- HellaSwag --------
elif dataset_name=="hellaswag":
# Because the test set doesn't have answers, we randomly split 80/20 the train set and use the 'valid' set as test set
dataset = HellaSwagDataset("datasets/hellaswag")
I = np.random.permutation(len(dataset.train_data))
I_train = I[:int(len(I)*0.8)]
I_valid = I[int(len(I)*0.8):]
I_test = None
# -------- LogiQA --------
elif dataset_name=="logiqa":
I_train = None
I_valid = None
I_test = None
# -------- Unrecognized --------
else:
raise ValueError(f"Unknown dataset '{dataset_name}'")
split_path = Path(f"datasets/{dataset_name}/split.pkl")
with split_path.open("wb") as file:
pickle.dump((I_train, I_valid, I_test), file)