-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtokenization.py
More file actions
61 lines (48 loc) · 1.9 KB
/
tokenization.py
File metadata and controls
61 lines (48 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from transformers import BertTokenizer
import pandas as pd
import torch
## Module-level BERT tokenizer shared by tokenize() below; from_pretrained
## downloads/caches the 'bert-base-uncased' vocab via the transformers hub.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
## Default fixed sequence length (in tokens) used for padding/truncation.
default_max_len = 128
## tokenizer
def tokenize(texts: list[str], max_len: int=default_max_len, tokenizer: BertTokenizer=tokenizer):
    '''
    Tokenize a list of texts with the given Hugging Face tokenizer.

    Defaults to the module-level bert-base-uncased tokenizer. Every text is
    padded/truncated to exactly ``max_len`` tokens, so all rows stack into
    fixed-width tensors.

    Args:
        texts (list of str): texts to tokenize
        max_len (int): fixed length (pad/truncate) of each tokenized text
        tokenizer (BertTokenizer): tokenizer used to encode the texts

    Returns:
        dict: {'input_ids': LongTensor of shape (len(texts), max_len),
               'attention_masks': LongTensor of shape (len(texts), max_len)}
    '''
    ## Fix: torch.cat raises RuntimeError on an empty tensor list, so an empty
    ## input previously crashed. Short-circuit with correctly shaped
    ## (0, max_len) tensors instead.
    if not texts:
        empty = torch.empty((0, max_len), dtype=torch.long)
        return {'input_ids': empty, 'attention_masks': empty.clone()}
    ## accumulate one (1, max_len) tensor per text
    input_ids = []
    attention_masks = []
    for text in texts:
        encoded_dict = tokenizer.encode_plus(
            text,
            add_special_tokens=True,    # prepend [CLS] / append [SEP]
            max_length=max_len,
            padding='max_length',       # pad every sequence up to max_len
            truncation=True,            # and cut longer ones down to it
            return_attention_mask=True,
            return_tensors='pt',        # each call yields (1, max_len) tensors
        )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])
    ## stack the per-text rows into (len(texts), max_len) tensors
    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    return {'input_ids': input_ids, 'attention_masks': attention_masks}
## Load the preprocessed dataset.
## NOTE(review): assumes a pickled DataFrame with 'text' and 'final_label'
## columns produced by an upstream formatting step — confirm against that script.
df = pd.read_pickle('data/formatted_data.pkl')
## Tokenize every text into fixed-width input_ids / attention_masks tensors.
tokenized_texts = tokenize(df['text'].tolist())
## Attach the labels as a tensor alongside the encodings.
## NOTE(review): presumes 'final_label' is already numerically encoded
## (torch.tensor would fail on raw strings) — verify upstream encoding.
tokenized_texts['labels'] = torch.tensor(df['final_label'].values)
## Persist the full dict (input_ids, attention_masks, labels) for training.
torch.save(tokenized_texts, 'data/text.pt')