-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuildDataSequenceClassification.py
More file actions
85 lines (69 loc) · 2.05 KB
/
buildDataSequenceClassification.py
File metadata and controls
85 lines (69 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# -*- coding: utf-8 -*-
"""
作者: terrychan
Blog: https://terrychan.org
# 说明:
自动构建数据集 预处理使用
SequenceClassification模式数据集
数据参考示例
dataDemo/SequenceClassification.csv
"""
import json
import os
import sys

import pandas as pd
import torch
# https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html
from sklearn import preprocessing
from torch.utils.data import random_split, TensorDataset

from config import *
# Build train/val/test TensorDataset splits for sequence classification from a
# CSV file: column 0 is read as the input text, column 1 as the label.
# NOTE(review): `tokenizer` is not defined in this file; it is presumably
# exported by `from config import *` -- confirm against config.py.

# Output directory for labels.json and the three dataset splits.
path = "out"
# Every tokenized sample is padded/truncated to exactly this many tokens.
MAX_LENGTH = 128
le = preprocessing.LabelEncoder()
print("""
SequenceClassification模式数据集
数据参考示例
dataDemo/SequenceClassification.csv
""")

# The CSV path is taken from the first CLI argument, otherwise prompted for.
if len(sys.argv) > 1:
    dataFile = sys.argv[1]
else:
    dataFile = input("数据集地址:")

# An empty path (user just pressed Enter) skips the build entirely.
if dataFile:
    df = pd.read_csv(dataFile)
    # drop_duplicates() returns a new frame; rebind it so that duplicate rows
    # are really removed before the data is encoded and split.
    df = df.drop_duplicates()
    print("数据集格式如下:")
    print(df)

    # Select each column positionally as a Series. Unlike
    # `df.iloc[:, [i]].squeeze()`, this still yields a list when the CSV
    # contains a single row.
    dataA = df.iloc[:, 0].astype(str).tolist()
    dataB = df.iloc[:, 1].astype(str).tolist()

    le.fit(dataB)
    labels = list(le.classes_)
    print("labels", labels)
    print("labels len:", len(labels))

    # Encode every label string as its integer index into `labels`.
    tgt = le.transform(dataB)
    inputsA = tokenizer(
        dataA,
        return_tensors="pt",
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True,
    )
    # NOTE(review): torch.Tensor(...) builds a default-dtype (floating point)
    # tensor, so class ids are stored as floats. Kept unchanged so the saved
    # files keep their dtype; presumably the consumer casts to long -- confirm.
    tgt = torch.Tensor(tgt)
    traindataset = TensorDataset(
        inputsA['input_ids'],
        inputsA['attention_mask'],
        tgt,
    )

    # 70% / 15% split; the test split takes the remainder so that the three
    # lengths always sum to the dataset size despite int() truncation.
    fullLen = len(traindataset)
    trainLen = int(fullLen * 0.7)
    valLen = int(fullLen * 0.15)
    testLen = fullLen - trainLen - valLen
    train, val, test = random_split(traindataset, [trainLen, valLen, testLen])

    # exist_ok=True tolerates an already existing directory while still
    # surfacing real failures (e.g. permission errors) instead of hiding them.
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "labels.json"), 'w', encoding="utf-8") as f:
        json.dump(labels, f, ensure_ascii=False)
    torch.save(train, os.path.join(path, "train.pkt"))
    torch.save(val, os.path.join(path, "val.pkt"))
    torch.save(test, os.path.join(path, "test.pkt"))

# The build above runs at module level; this guard is intentionally a no-op.
if __name__ == '__main__':
    pass