-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathbert_main.py
More file actions
executable file
·334 lines (302 loc) · 14.3 KB
/
bert_main.py
File metadata and controls
executable file
·334 lines (302 loc) · 14.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
#!/usr/bin/env python
# encoding=utf-8
'''
@Time : 2020/06/14 17:45:13
@Author : zhiyang.zzy
@Contact : zhiyangchou@gmail.com
@Desc :
1. 随机插入mask,使用bert来生成 mask 的内容,来丰富句子
2. 随机将某些词语mask,使用bert来生成 mask 的内容。
- 使用贪心算法,每次最优。
- beam search方法,每次保留最优的前n个,最多num_beams个句子。(注意句子数量大于num_beams个时,剔除概率最低的,防止内存溢出)。
'''
# here put the import lib
import nltk
import tensorflow as tf
# from transformers import *
import heapq
from tensorflow.python.ops.gen_math_ops import mod
from zhon.hanzi import punctuation
import string
import jieba
import numpy as np
import sys
from util import read_file
from bert_modify import modeling as modeling, tokenization, optimization
from collections import defaultdict
# Report which TensorFlow build is loaded; this module uses the TF1
# graph/session API (tf.placeholder, tf.Session).
print(tf.__version__)
# Combined punctuation set: ASCII punctuation followed by the Chinese
# (full-width) punctuation characters from zhon.hanzi.
punc = "".join((string.punctuation, punctuation))
def gather_indexes(sequence_tensor, positions):
    """Collect the hidden vectors located at `positions` for every example in a batch.

    Args:
        sequence_tensor: rank-3 tensor [batch, seq_len, width]; the rank is
            enforced by ``modeling.get_shape_list(..., expected_rank=3)``.
        positions: int32 tensor of per-example positions, [batch, n_positions]
            (it is broadcast against a [batch, 1] offset column).

    Returns:
        Tensor of shape [batch * n_positions, width].
    """
    batch_size, seq_length, width = modeling.get_shape_list(
        sequence_tensor, expected_rank=3)
    # In the flattened [batch * seq_len, width] view, example b starts at row b * seq_len.
    row_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_indices = tf.reshape(positions + row_offsets, [-1])
    flat_sequence = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
    return tf.gather(flat_sequence, flat_indices)
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions):
    """Build the masked-LM head and return vocabulary logits at `positions`.

    Only the logits are returned — no loss and no log-probabilities.

    Args:
        bert_config: BertConfig providing hidden_size, hidden_act,
            initializer_range and vocab_size.
        input_tensor: BERT sequence output, [batch, seq_len, hidden].
        output_weights: embedding table, used (transposed) as the output
            projection so input and output embeddings are tied.
        positions: int32 positions to predict, [batch, n_positions].

    Returns:
        Logits tensor of shape [batch * n_positions, vocab_size].
    """
    gathered = gather_indexes(input_tensor, positions)
    # Scope names must match the pre-trained checkpoint's variable names.
    with tf.variable_scope("cls/predictions"):
        # One extra non-linear transform before the output layer (part of
        # BERT pre-training; not used by downstream fine-tuning).
        with tf.variable_scope("transform"):
            hidden = tf.layers.dense(
                gathered,
                units=bert_config.hidden_size,
                activation=modeling.get_activation(bert_config.hidden_act),
                kernel_initializer=modeling.create_initializer(
                    bert_config.initializer_range))
            hidden = modeling.layer_norm(hidden)
        # The projection reuses the input embeddings; only the bias is output-specific.
        output_bias = tf.get_variable(
            "output_bias",
            shape=[bert_config.vocab_size],
            initializer=tf.zeros_initializer())
        scores = tf.matmul(hidden, output_weights, transpose_b=True)
        return tf.nn.bias_add(scores, output_bias)
class BertAugmentor(object):
    """Generate variants of Chinese queries with a pre-trained BERT masked LM.

    Variants are produced either by inserting [MASK] tokens at word boundaries
    (``word_insert``) or by masking whole jieba-segmented words
    (``word_replace``). BERT then fills the masks in with a small beam search
    (``gen_sen``).
    """

    def __init__(self, model_dir, beam_size=5):
        """Load config and vocab from `model_dir`, build the graph and open a session.

        Args:
            model_dir: directory holding bert_config.json, the bert_model.ckpt
                checkpoint and vocab.txt. A trailing path separator is optional.
            beam_size: maximum number of candidates kept per masked sentence.
        """
        import os  # local import: only needed to build the model file paths

        self.beam_size = beam_size  # max candidates kept per masked sentence
        # os.path.join (not "+") so model_dir works with or without a trailing "/".
        self.bert_config_file = os.path.join(model_dir, 'bert_config.json')
        self.init_checkpoint = os.path.join(model_dir, 'bert_model.ckpt')
        self.bert_vocab_file = os.path.join(model_dir, 'vocab.txt')
        self.bert_config = modeling.BertConfig.from_json_file(
            self.bert_config_file)
        # Project tokenizer (bert_modify.tokenization). Per the original author's
        # note, it splits Chinese into tokens and digits/English into characters.
        self.token = tokenization.CharTokenizer(vocab_file=self.bert_vocab_file)
        self.mask_token = "[MASK]"
        self.mask_id = self.token.convert_tokens_to_ids([self.mask_token])[0]
        self.cls_token = "[CLS]"
        self.cls_id = self.token.convert_tokens_to_ids([self.cls_token])[0]
        self.sep_token = "[SEP]"
        self.sep_id = self.token.convert_tokens_to_ids([self.sep_token])[0]
        self.build()
        self.build_sess()

    def __del__(self):
        """Release the TensorFlow session when the augmentor is garbage-collected."""
        try:
            self.close_sess()
        except Exception:
            # During interpreter shutdown TensorFlow may already be torn down;
            # exceptions escaping __del__ would only be printed as warnings anyway.
            pass

    def build(self):
        """Build the inference graph and register the checkpoint to load.

        Creates the input placeholders, the BERT encoder, the masked-LM head and
        a softmax over the vocabulary (``self.predict_prob``).
        """
        self.input_ids = tf.placeholder(
            tf.int32, shape=[None, None], name='input_ids')
        self.input_mask = tf.placeholder(
            tf.int32, shape=[None, None], name='input_masks')
        self.segment_ids = tf.placeholder(
            tf.int32, shape=[None, None], name='segment_ids')
        self.masked_lm_positions = tf.placeholder(
            tf.int32, shape=[None, None], name='masked_lm_positions')
        self.model = modeling.BertModel(
            config=self.bert_config,
            is_training=False,
            input_ids=self.input_ids,
            input_mask=self.input_mask,
            token_type_ids=self.segment_ids,
            use_one_hot_embeddings=False)
        self.masked_logits = get_masked_lm_output(
            self.bert_config, self.model.get_sequence_output(),
            self.model.get_embedding_table(), self.masked_lm_positions)
        self.predict_prob = tf.nn.softmax(self.masked_logits, axis=-1)
        tvars = tf.trainable_variables()
        assignment, _ = modeling.get_assignment_map_from_checkpoint(
            tvars, self.init_checkpoint)
        # init_from_checkpoint replaces the variables' initializers, so the
        # global_variables_initializer() run in build_sess loads checkpoint values.
        tf.train.init_from_checkpoint(self.init_checkpoint, assignment)

    def build_sess(self):
        """Open a session on the default graph and initialize all variables."""
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def close_sess(self):
        """Close the session if one is open. Safe to call more than once."""
        sess = getattr(self, "sess", None)
        if sess is not None:
            sess.close()
            self.sess = None

    @staticmethod
    def _top_k_indices(probs, k):
        """Return indices of the `k` largest entries of `probs`, largest first.

        Ties are broken by lower index, and each index appears at most once.
        Looking values up with list.index would instead repeat the first tied index.
        """
        order = np.argsort(-np.asarray(probs), kind="stable")
        return [int(i) for i in order[:max(k, 0)]]

    def predict_single_mask(self, word_ids: list, mask_index: int, prob: float = None):
        """Predict the token at position `mask_index` of one token-id list.

        Args:
            word_ids: token ids of one sentence (usually containing [MASK]).
            mask_index: position in `word_ids` to predict.
            prob: if truthy, candidates whose probability is below it are dropped.

        Returns:
            Up to ``self.beam_size`` ``[word_ids_copy, probability]`` pairs, best
            first. Each copy has the predicted id written at `mask_index`.
        """
        fd = {
            self.input_ids: [word_ids],
            self.input_mask: [[1] * len(word_ids)],
            self.segment_ids: [[0] * len(word_ids)],
            self.masked_lm_positions: [[mask_index]],
        }
        mask_probs = self.sess.run(self.predict_prob, feed_dict=fd)
        word_ids_out = []
        for row in mask_probs:
            for token_id in self._top_k_indices(row, self.beam_size):
                p = float(row[token_id])
                if prob and p < prob:
                    continue
                candidate = list(word_ids)
                candidate[mask_index] = token_id
                word_ids_out.append([candidate, p])
        return word_ids_out

    def predict_batch_mask(self, query_ids: list, mask_indexes: list, prob: float = 0.5):
        """Predict one masked position per query for a batch of token-id lists.

        Args:
            query_ids: list of token-id lists, possibly of different lengths.
                Shorter ones are right-padded with id 0 (the [PAD] id in the
                standard BERT vocab) and masked out via input_mask.
            mask_indexes: one single-element list per query, e.g. [[3], [5]],
                giving the position to predict.
            prob: if truthy, candidates whose probability is below it are
                dropped (same rule as predict_single_mask).

        Returns:
            One list per query of ``[word_ids_copy, probability]`` pairs, best
            first. The copies are unpadded, with the prediction written in.
        """
        if not query_ids:
            return []
        max_len = max(len(ids) for ids in query_ids)
        padded_ids = [list(ids) + [0] * (max_len - len(ids)) for ids in query_ids]
        input_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in query_ids]
        # Single-sentence input: all tokens are segment 0, as in predict_single_mask.
        segment_ids = [[0] * max_len for _ in query_ids]
        fd = {self.input_ids: padded_ids, self.input_mask: input_mask,
              self.segment_ids: segment_ids, self.masked_lm_positions: mask_indexes}
        mask_probs = self.sess.run(self.predict_prob, feed_dict=fd)
        word_ids_out = []
        for row, ids, mask_index in zip(mask_probs, query_ids, mask_indexes):
            cur_out = []
            for token_id in self._top_k_indices(row, self.beam_size):
                p = float(row[token_id])
                if prob and p < prob:
                    continue
                candidate = list(ids)
                candidate[mask_index[0]] = token_id
                cur_out.append([candidate, p])
            word_ids_out.append(cur_out)
        return word_ids_out

    def gen_sen(self, word_ids: list, indexes: list):
        """Fill every position in `indexes` with a left-to-right beam search.

        Each query can have a different number of masks, so queries are
        predicted one at a time rather than batched.

        Returns:
            Up to ``self.beam_size`` ``[tokens, probability]`` pairs, best first.
            `tokens` are vocabulary strings (via ``self.token.id2vocab``) and
            `probability` is the product of the per-position probabilities.
        """
        beams = []
        for step, mask_index in enumerate(indexes):
            if step == 0:
                beams = self.predict_single_mask(word_ids, mask_index)
            else:
                expanded = []
                for beam_ids, beam_prob in beams:
                    for cand_ids, cand_prob in self.predict_single_mask(beam_ids, mask_index):
                        expanded.append([cand_ids, cand_prob * beam_prob])
                beams = expanded
            # Prune to the self.beam_size most probable partial sentences at each step.
            beams = sorted(beams, key=lambda x: x[1], reverse=True)[:self.beam_size]
        return [[[self.token.id2vocab[t] for t in ids], p] for ids, p in beams]

    def _encode_query(self, query):
        """Tokenize `query` and return its ids wrapped as [CLS] ... [SEP]."""
        split_tokens = self.token.tokenize(query)
        word_ids = self.token.convert_tokens_to_ids(split_tokens)
        return [self.cls_id] + list(word_ids) + [self.sep_id]

    def _rank_variants(self, variants):
        """Fill each (word_ids, mask_positions) variant and keep the best sentences.

        Returns:
            The ``self.beam_size`` most probable generated sentences, best first,
            as strings with the [CLS]/[SEP] tokens stripped.
        """
        candidates = []
        for word_ids, mask_positions in variants:
            candidates.extend(self.gen_sen(word_ids, indexes=mask_positions))
        candidates.sort(key=lambda x: x[1], reverse=True)
        return ["".join(tokens[1:-1]) for tokens, _ in candidates[:self.beam_size]]

    def word_insert(self, query):
        """Insert 1-3 [MASK] tokens at each word boundary and let BERT fill them.

        Boundaries come from jieba segmentation: before the first word and after
        every word. The number of masks per boundary is drawn from
        ``np.random.randint(1, 4)``.

        Returns:
            Up to ``self.beam_size`` generated sentences, best first.
        """
        word_ids = self._encode_query(query)
        # Index 0 holds [CLS], so the first character of the query is at index 1.
        # NOTE(review): assumes the tokenizer emits one token per character of
        # `query`; otherwise these positions drift — confirm with CharTokenizer.
        insert_points = [1]
        for word in jieba.cut(query, cut_all=False):
            insert_points.append(insert_points[-1] + len(word))
        variants = []
        for point in insert_points:
            n_masks = int(np.random.randint(1, 4))
            masked = word_ids[:point] + [self.mask_id] * n_masks + word_ids[point:]
            variants.append((masked, list(range(point, point + n_masks))))
        return self._rank_variants(variants)

    def word_replace(self, query):
        """Mask each jieba-segmented word in turn and let BERT regenerate it.

        Each word produces one variant in which all of its character positions
        are replaced by [MASK].

        Returns:
            The best ``self.beam_size`` generated sentences across all variants,
            best first.
        """
        word_ids = self._encode_query(query)
        variants = []
        # Index 0 holds [CLS]; same one-token-per-character assumption as word_insert.
        start = 1
        for word in jieba.cut(query, cut_all=False):
            positions = list(range(start, start + len(word)))
            masked = list(word_ids)
            for pos in positions:
                masked[pos] = self.mask_id
            variants.append((masked, positions))
            start += len(word)
        return self._rank_variants(variants)

    def insert_word2queries(self, queries: list, beam_size=10):
        """Run word_insert on every query; returns {query: [sentences]}.

        Note: permanently sets ``self.beam_size`` to `beam_size`.
        """
        self.beam_size = beam_size
        out_map = defaultdict(list)
        for query in queries:
            out_map[query] = self.word_insert(query)
        return out_map

    def replace_word2queries(self, queries: list, beam_size=10):
        """Run word_replace on every query; returns {query: [sentences]}.

        Note: permanently sets ``self.beam_size`` to `beam_size`.
        """
        self.beam_size = beam_size
        out_map = defaultdict(list)
        for query in queries:
            out_map[query] = self.word_replace(query)
        return out_map

    def predict(self, query_arr, beam_size=None):
        """Fill every "[MASK]" in a token list, e.g. ["w1", "w2", "[MASK]", ...].

        No [CLS]/[SEP] tokens are added. If `beam_size` is given, it permanently
        replaces ``self.beam_size``.

        Returns:
            Up to beam_size ``[sentence_string, probability]`` pairs, best first.
        """
        self.beam_size = beam_size if beam_size else self.beam_size
        word_ids = self.token.convert_tokens_to_ids(query_arr)
        indexes = [i for i, tok in enumerate(query_arr) if tok == self.mask_token]
        out_queries = self.gen_sen(word_ids, indexes)
        return [["".join(tokens), p] for tokens, p in out_queries]
def _write_augment_file(path, result):
    """Write `result` ({query: [candidates]}) as UTF-8 lines "query<TAB>c1;c2;..."."""
    with open(path, 'w', encoding='utf-8') as out:
        for query, candidates in result.items():
            out.write("{}\t{}\n".format(query, ';'.join(candidates)))


def augment(file_, model_dir=None, beam_size=20):
    """Augment every query in `file_` with BERT and write the results next to it.

    Args:
        file_: input file read via util.read_file (one query per line).
        model_dir: directory of the pre-trained BERT model. Chinese BERT download:
            https://github.com/InsaneLife/ChineseNLPCorpus#%E9%A2%84%E8%AE%AD%E7%BB%83%E8%AF%8D%E5%90%91%E9%87%8For%E6%A8%A1%E5%9E%8B
        beam_size: number of generated variants kept per query.

    Writes ``file_ + ".augment.bert_replace"`` (word replacement) and
    ``file_ + ".augment.bert_insert"`` (mask insertion).

    Raises:
        ValueError: if `model_dir` is not given.
    """
    if not model_dir:
        raise ValueError("must feed params:[model_dir]")
    queries = read_file(file_)
    mask_model = BertAugmentor(model_dir)
    # Replacement: mask each word in turn and let BERT predict alternatives.
    replace_result = mask_model.replace_word2queries(queries, beam_size=beam_size)
    _write_augment_file(file_ + ".augment.bert_replace", replace_result)
    # Insertion: insert [MASK] tokens at word boundaries and let BERT fill them.
    insert_result = mask_model.insert_word2queries(queries, beam_size=beam_size)
    print("Augmentor's result:", insert_result)
    _write_augment_file(file_ + ".augment.bert_insert", insert_result)
    # Demo: fill two leading [MASK] tokens of a fixed example query.
    out_queries = mask_model.predict(["[MASK]", "[MASK]", "卖", "账", "号", "吗"], beam_size=5)
    print("Mask prediction result:", out_queries)
if __name__ == "__main__":
    # Usage: python bert_main.py [input_file] [model_dir]
    # Without arguments, the original hard-coded paths are used.
    # Chinese BERT download: https://github.com/InsaneLife/ChineseNLPCorpus#%E9%A2%84%E8%AE%AD%E7%BB%83%E8%AF%8D%E5%90%91%E9%87%8For%E6%A8%A1%E5%9E%8B
    input_file = sys.argv[1] if len(sys.argv) > 1 else "data/input"
    model_dir = (sys.argv[2] if len(sys.argv) > 2
                 else '/Volumes/HddData/ProjectData/NLP/bert/chinese_L-12_H-768_A-12/')
    # Input file: one query per line.
    augment(input_file, model_dir=model_dir)