@@ -110,7 +110,7 @@ def train(gen_config):
110110 gen_loss_summary = tf .Summary ()
111111 gen_writer = tf .summary .FileWriter (gen_config .tensorboard_dir , sess .graph )
112112
113- while current_step < 100 :
113+ while current_step < 1000 :
114114 # Choose a bucket according to disc_data distribution. We pick a random number
115115 # in [0, 1] and use the corresponding interval in train_buckets_scale.
116116 random_number_01 = np .random .random_sample ()
@@ -239,6 +239,7 @@ def decoder(gen_config):
239239
240240 encoder_inputs , decoder_inputs , target_weights , batch_source_encoder , batch_source_decoder = \
241241 model .get_batch (train_set , bucket_id , gen_config .batch_size )
242+
242243
243244 _ , _ , out_logits = model .step (sess , encoder_inputs , decoder_inputs , target_weights , bucket_id ,
244245 forward_only = True )
@@ -255,14 +256,7 @@ def decoder(gen_config):
255256
256257 for seq in tokens_t :
257258 if data_utils .EOS_ID in seq :
258- '''
259- seq[:seq.index(data_utils.EOS_ID)][:gen_config.buckets[bucket_id][1]]
260- seq是一维的,乍一看以为上面表达式把他当作二维处理,但实际上s=[1,2,3,4]
261- s[:3][:2]输出的是[1,2],也就是先截取[0:3]的数据,在截取[0:2]的数据
262- 而没有冒号,即s[3][2]是当做二维处理,在这边这么写是错的
263- '''
264-
265- #resps的shape为[[[vocab_size]],.......] 倒数第二层:decoder_size 最外一层:batch_size
259+
266260 resps .append (seq [:seq .index (data_utils .EOS_ID )][:gen_config .buckets [bucket_id ][1 ]])
267261 else :
268262 resps .append (seq [:gen_config .buckets [bucket_id ][1 ]])
@@ -272,7 +266,6 @@ def decoder(gen_config):
272266 answer_str = " " .join ([str (rev_vocab [an ]) for an in answer [:- 1 ]])
273267 disc_train_answer .write (answer_str )
274268 disc_train_answer .write ("\n " )
275-
276269 query_str = " " .join ([str (rev_vocab [qu ]) for qu in query ])
277270 disc_train_query .write (query_str )
278271 disc_train_query .write ("\n " )
0 commit comments