-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathFastTokenizers.py
More file actions
207 lines (174 loc) · 9.21 KB
/
FastTokenizers.py
File metadata and controls
207 lines (174 loc) · 9.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils import is_tf_available,is_torch_available
from transformers.tokenization_bert import VOCAB_FILES_NAMES,PRETRAINED_VOCAB_FILES_MAP,PRETRAINED_INIT_CONFIGURATION,PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
from transformers import tokenization_distilbert as tdb
import os
import logging
logger = logging.getLogger(__name__)
import six
if is_tf_available():
import tensorflow as tf
if is_torch_available():
import torch
class FastPreTrainedTokenizer(PreTrainedTokenizer):
def __init__(self, **kwargs):
super(FastPreTrainedTokenizer, self).__init__(**kwargs)
@property
def tokenizer(self):
if self._tokenizer is None:
raise NotImplementedError
return self._tokenizer
@property
def decoder(self):
if self._decoder is None:
raise NotImplementedError
return self._decoder
@property
def vocab_size(self):
return self.tokenizer._tokenizer.get_vocab_size()
def __len__(self):
return self.tokenizer._tokenizer.get_vocab_size()
def _update_special_tokens(self):
self.tokenizer.add_special_tokens(self.all_special_tokens)
@staticmethod
def _convert_encoding(encoding,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
encoding_dict = {
"input_ids": encoding.ids,
}
if return_token_type_ids:
encoding_dict["token_type_ids"] = encoding.type_ids
if return_attention_mask:
encoding_dict["attention_mask"] = encoding.attention_mask
if return_overflowing_tokens:
overflowing = encoding.overflowing
encoding_dict["overflowing_tokens"] = overflowing.ids if overflowing is not None else []
if return_special_tokens_mask:
encoding_dict["special_tokens_mask"] = encoding.special_tokens_mask
# Prepare inputs as tensors if asked
if return_tensors == 'tf' and is_tf_available():
encoding_dict["input_ids"] = tf.constant([encoding_dict["input_ids"]])
encoding_dict["token_type_ids"] = tf.constant([encoding_dict["token_type_ids"]])
if "attention_mask" in encoding_dict:
encoding_dict["attention_mask"] = tf.constant([encoding_dict["attention_mask"]])
elif return_tensors == 'pt' and is_torch_available():
encoding_dict["input_ids"] = torch.tensor([encoding_dict["input_ids"]])
encoding_dict["token_type_ids"] = torch.tensor([encoding_dict["token_type_ids"]])
if "attention_mask" in encoding_dict:
encoding_dict["attention_mask"] = torch.tensor([encoding_dict["attention_mask"]])
elif return_tensors is not None:
logger.warning(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors))
return encoding_dict
def encode_plus(self,
text,
text_pair=None,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
**kwargs):
encoding = self.tokenizer.encode(text, text_pair)
return self._convert_encoding(encoding,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask)
def tokenize(self, text):
return self.tokenizer.encode(text).tokens
def _convert_token_to_id_with_added_voc(self, token):
return self.tokenizer.token_to_id(token)
def _convert_id_to_token(self, index):
return self.tokenizer.id_to_token(int(index))
def convert_tokens_to_string(self, tokens):
return self.decoder.decode(tokens)
def add_tokens(self, new_tokens):
self.tokenizer.add_tokens(new_tokens)
def encode_batch(self, texts,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
return [self._convert_encoding(encoding,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask)
for encoding in self.tokenizer.encode_batch(texts)]
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
text = self.tokenizer.decode(token_ids, skip_special_tokens)
if clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text)
return clean_text
else:
return text
def decode_batch(self, ids_batch, skip_special_tokens=False, clear_up_tokenization_spaces=True):
return [self.clean_up_tokenization(text)
if clear_up_tokenization_spaces else text
for text in self.tokenizer.decode_batch(ids_batch, skip_special_tokens)]
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, self.vocab_files_names['vocab_file'])
else:
vocab_file = vocab_path
self.tokenizer._tokenizer.model.save(vocab_path,"")
return (vocab_file,)
class BertTokenizerFast(FastPreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
mask_token="[MASK]", tokenize_chinese_chars=True,
max_length=None, pad_to_max_length=False, stride=0,
truncation_strategy='longest_first', add_special_tokens=True, **kwargs):
try:
from tokenizers import BertWordPieceTokenizer
super(BertTokenizerFast, self).__init__(unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, **kwargs)
self._tokenizer = BertWordPieceTokenizer(
vocab_file,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
handle_chinese_chars=tokenize_chinese_chars,
lowercase=do_lower_case,
add_special_tokens=add_special_tokens
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
self._update_special_tokens()
if max_length is not None:
self._tokenizer.with_truncation(max_length, stride, truncation_strategy)
except (AttributeError, ImportError) as e:
logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.2.1`")
raise e
class DistilBertTokenizerFast(BertTokenizerFast):
r"""
Constructs a DistilBertTokenizer.
:class:`~transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
minimum of this value (if specified) and the underlying BERT model's sequence length.
never_split: List of tokens which will never be split during tokenization. Only has an effect when
do_wordpiece_only=False
"""
vocab_files_names = tdb.VOCAB_FILES_NAMES
pretrained_vocab_files_map = tdb.PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = tdb.PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES