Skip to content

Commit 0034d92

Browse files
zhaoyingjun
authored and committed
更新编码问题和生成对话
1 parent b7d8bbd commit 0034d92

4 files changed

Lines changed: 35 additions & 34 deletions

File tree

lessonTen/seqGan chatbotv2.0/execute.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -306,36 +306,44 @@ def decoder_online(sess,gen_config, model, vocab,rev_vocab, inputs):
306306
# Get output logits for the sentence.
307307
_, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
308308

309-
# This is a greedy decoder - outputs are just argmaxes of output_logits.
310-
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits[0]]
311-
312309
# If there is an EOS symbol in outputs, cut them at that point.
313-
if data_utils.EOS_ID in outputs:
314-
outputs = outputs[:outputs.index(prepareData.EOS_ID)]
315-
316-
return " ".join([tf.compat.as_str(rev_vocab[output]) for output in outputs])
317-
318-
319-
310+
tokens = []
311+
resps = []
312+
for seq in output_logits:
313+
token = []
314+
for t in seq:
315+
token.append(int(np.argmax(t, axis=0)))
316+
tokens.append(token)
317+
tokens_t = []
318+
for col in range(len(tokens[0])):
319+
tokens_t.append([tokens[row][col] for row in range(len(tokens))])
320+
321+
for seq in tokens_t:
322+
if data_utils.EOS_ID in seq:
323+
resps.append(seq[:seq.index(data_utils.EOS_ID)][:gen_config.buckets[bucket_id][1]])
324+
else:
325+
resps.append(seq[:gen_config.buckets[bucket_id][1]])
326+
for resp in resps:
327+
resq_str= " ".join([tf.compat.as_str(rev_vocab[output]) for output in resp])
328+
return resq_str
320329

321-
322330

323331
def main(_):
324332
# step_1 training gen model
325-
gen_pre_train()
333+
#gen_pre_train()
326334

327-
print("*****请注释掉本行代码,以及上行代码gen_pre_train(),下行代码sys.exit(0)然后继续执行execute.py********")
328-
sys.exit(0)
335+
#print("*****请注释掉本行代码,以及上行代码gen_pre_train(),下行代码sys.exit(0)然后继续执行execute.py********")
336+
#sys.exit(0)
329337
# step_2 gen training data for disc
330-
gen_disc()
338+
#gen_disc()
331339

332-
print("*****请注释掉本行代码,以及上行代码gen_disc(),下行代码sys.exit(0)然后继续执行execute.py********")
333-
sys.exit(0)
340+
#print("*****请注释掉本行代码,以及上行代码gen_disc(),下行代码sys.exit(0)然后继续执行execute.py********")
341+
#sys.exit(0)
334342

335343
# step_3 training disc model
336-
disc_pre_train()
337-
print("*****请注释掉本行代码,以及上行代码disc_pre_train(),下行代码sys.exit(0)然后继续执行execute.py********")
338-
sys.exit(0)
344+
#disc_pre_train()
345+
#print("*****请注释掉本行代码,以及上行代码disc_pre_train(),下行代码sys.exit(0)然后继续执行execute.py********")
346+
#sys.exit(0)
339347
# step_4 training al model
340348
al_train()
341349

lessonTen/seqGan chatbotv2.0/gen/generator.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def train(gen_config):
110110
gen_loss_summary = tf.Summary()
111111
gen_writer = tf.summary.FileWriter(gen_config.tensorboard_dir, sess.graph)
112112

113-
while current_step<100:
113+
while current_step<1000:
114114
# Choose a bucket according to disc_data distribution. We pick a random number
115115
# in [0, 1] and use the corresponding interval in train_buckets_scale.
116116
random_number_01 = np.random.random_sample()
@@ -239,6 +239,7 @@ def decoder(gen_config):
239239

240240
encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder = \
241241
model.get_batch(train_set, bucket_id, gen_config.batch_size)
242+
242243

243244
_, _, out_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
244245
forward_only=True)
@@ -255,14 +256,7 @@ def decoder(gen_config):
255256

256257
for seq in tokens_t:
257258
if data_utils.EOS_ID in seq:
258-
'''
259-
seq[:seq.index(data_utils.EOS_ID)][:gen_config.buckets[bucket_id][1]]
260-
seq是一维的,乍一看以为上面表达式把他当作二维处理,但实际上s=[1,2,3,4]
261-
s[:3][:2]输出的是[1,2],也就是先截取[0:3]的数据,在截取[0:2]的数据
262-
而没有冒号,即s[3][2]是当做二维处理,在这边这么写是错的
263-
'''
264-
265-
#resps的shape为[[[vocab_size]],.......] 倒数第二层:decoder_size 最外一层:batch_size
259+
266260
resps.append(seq[:seq.index(data_utils.EOS_ID)][:gen_config.buckets[bucket_id][1]])
267261
else:
268262
resps.append(seq[:gen_config.buckets[bucket_id][1]])
@@ -272,7 +266,6 @@ def decoder(gen_config):
272266
answer_str = " ".join([str(rev_vocab[an]) for an in answer[:-1]])
273267
disc_train_answer.write(answer_str)
274268
disc_train_answer.write("\n")
275-
276269
query_str = " ".join([str(rev_vocab[qu]) for qu in query])
277270
disc_train_query.write(query_str)
278271
disc_train_query.write("\n")

lessonTen/seqGan chatbotv2.0/utils/conf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ class disc_config(object):
55
batch_size = 64
66
lr = 0.2
77
lr_decay = 0.9
8-
vocab_size = 25000
8+
vocab_size = 2500
99
embed_dim = 64
1010
steps_per_checkpoint = 20
1111
#hidden_neural_size = 128
@@ -36,7 +36,7 @@ class gen_config(object):
3636
batch_size = 64
3737
emb_dim = 64
3838
num_layers = 2
39-
vocab_size = 25000
39+
vocab_size = 2500
4040
train_dir = "./gen_data/"
4141
name_model = "st_model"
4242
tensorboard_dir = "./tensorboard/gen_log/"
@@ -59,7 +59,7 @@ class GSTConfig(object):
5959
batch_size = 256
6060
emb_dim = 1024
6161
num_layers = 2
62-
vocab_size = 25000
62+
vocab_size = 2500
6363
train_dir = "./gst_data/"
6464
name_model = "st_model"
6565
tensorboard_dir = "./tensorboard/gst_log/"

lessonTen/seqGan chatbotv2.0/utils/data_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def initialize_vocabulary(vocabulary_path):
119119
"""
120120
if gfile.Exists(vocabulary_path):
121121
rev_vocab = []
122-
with gfile.GFile(vocabulary_path, mode="rb") as f:
122+
with open(vocabulary_path, mode="r") as f:
123123
rev_vocab.extend(f.readlines())
124124
rev_vocab = [line.strip() for line in rev_vocab]
125125
vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])

0 commit comments

Comments (0)