Skip to content

Commit 196244e

Browse files
committed
enable all tests
1 parent 6b22eed commit 196244e

2 files changed

Lines changed: 48 additions & 92 deletions

File tree

python/sparkdl/graph/input.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,6 @@ def _build_impl(self, feed_names, fetch_names):
225225
gin = TFInputGraph._new_obj_internal()
226226
assert (feed_names is None) == (fetch_names is None)
227227
must_have_sig_def = fetch_names is None
228-
print('builder-session', repr(self.sess))
229228
# NOTE(phi-dbq): both have to be set to default
230229
with self.sess.as_default(), self.graph.as_default():
231230
_ginfo = self.import_graph_fn(self.sess)
@@ -249,47 +248,3 @@ def build(self, feed_names=None, fetch_names=None):
249248
if self._should_clean:
250249
self.sess.close()
251250
return gin
252-
253-
# def the_rest(input_mapping, output_mapping):
254-
# graph = tf.Graph()
255-
# with tf.Session(graph=graph) as sess:
256-
# # Append feeds and input mapping
257-
# _input_mapping = {}
258-
# if isinstance(input_mapping, dict):
259-
# input_mapping = input_mapping.items()
260-
# for input_colname, tnsr_or_sig in input_mapping:
261-
# if sig_def:
262-
# tnsr = sig_def.inputs[tnsr_or_sig].name
263-
# else:
264-
# tnsr = tnsr_or_sig
265-
# _input_mapping[input_colname] = tfx.op_name(graph, tnsr)
266-
# input_mapping = _input_mapping
267-
268-
# # Append fetches and output mapping
269-
# fetches = []
270-
# _output_mapping = {}
271-
# # By default the output columns will have the name of their
272-
# # corresponding `tf.Graph` operation names.
273-
# # We have to convert them to the user specified output names
274-
# if isinstance(output_mapping, dict):
275-
# output_mapping = output_mapping.items()
276-
# for tnsr_or_sig, requested_colname in output_mapping:
277-
# if sig_def:
278-
# tnsr = sig_def.outputs[tnsr_or_sig].name
279-
# else:
280-
# tnsr = tnsr_or_sig
281-
# fetches.append(tfx.get_tensor(graph, tnsr))
282-
# tf_output_colname = tfx.op_name(graph, tnsr)
283-
# # NOTE(phi-dbq): put the check here as it will be the entry point to construct
284-
# # a `TFInputGraph` object.
285-
# assert tf_output_colname not in _output_mapping, \
286-
# "operation {} has multiple output tensors and ".format(tf_output_colname) + \
287-
# "at least two of them are used in the output DataFrame. " + \
288-
# "Operation names are used to name columns which leads to conflicts. " + \
289-
# "You can apply `tf.identity` ops to each to avoid name conflicts."
290-
# _output_mapping[tf_output_colname] = requested_colname
291-
# output_mapping = _output_mapping
292-
293-
# gdef = tfx.strip_and_freeze_until(fetches, graph, sess)
294-
295-
# return TFInputGraph(gdef), input_mapping, output_mapping

python/tests/transformers/tf_tensor_test.py

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from pyspark.sql.types import Row
2929

30+
from sparkdl.graph.builder import IsolatedSession
3031
from sparkdl.graph.input import *
3132
import sparkdl.graph.utils as tfx
3233
from sparkdl.transformers.tf_tensor import TFTransformer
@@ -60,6 +61,15 @@ def setUp(self):
6061
def tearDown(self):
6162
shutil.rmtree(self.model_output_root, ignore_errors=True)
6263

64+
def _build_default_session_tests(self, sess):
65+
gin = TFInputGraph.fromGraph(
66+
sess.graph, sess, self.feed_names, self.fetch_names)
67+
self.build_standard_transformers(sess, gin)
68+
69+
gin = TFInputGraph.fromGraphDef(
70+
sess.graph.as_graph_def(), self.feed_names, self.fetch_names)
71+
self.build_standard_transformers(sess, gin)
72+
6373
def build_standard_transformers(self, sess, tf_input_graph):
6474
def _add_transformer(imap, omap):
6575
trnsfmr = TFTransformer(
@@ -113,7 +123,7 @@ def _run_test_in_tf_session(self):
113123

114124
# Build the TensorFlow graph
115125
graph = tf.Graph()
116-
with tf.Session(graph=graph) as sess:
126+
with tf.Session(graph=graph) as sess, graph.as_default():
117127
# Build test graph and transformers from here
118128
yield sess
119129

@@ -147,8 +157,7 @@ def _run_test_in_tf_session(self):
147157
_results.append(np.ravel(curr_res))
148158
out_tgt = np.hstack(_results)
149159

150-
self.assertTrue(np.allclose(out_ref, out_tgt),
151-
msg=repr(transfomer))
160+
self.assertTrue(np.allclose(out_ref, out_tgt), msg=repr(transfomer))
152161

153162

154163
def test_build_from_tf_graph(self):
@@ -159,13 +168,7 @@ def test_build_from_tf_graph(self):
159168
_ = tf.reduce_mean(x, axis=1, name=self.output_op_name)
160169
# End building graph
161170

162-
# Begin building transformers
163-
self.build_standard_transformers(
164-
sess, TFInputGraph.fromGraph(sess.graph, sess, self.feed_names, self.fetch_names))
165-
gdef = sess.graph.as_graph_def()
166-
self.build_standard_transformers(
167-
sess, TFInputGraph.fromGraphDef(gdef, self.feed_names, self.fetch_names))
168-
# End building transformers
171+
self._build_default_session_tests(sess)
169172

170173

171174
def test_build_from_saved_model(self):
@@ -224,6 +227,10 @@ def test_build_from_saved_model(self):
224227
saved_model_dir, serving_tag, self.feed_names, self.fetch_names)
225228
self.build_standard_transformers(sess, gin)
226229

230+
gin = TFInputGraph.fromGraph(
231+
sess.graph, sess, self.feed_names, self.fetch_names)
232+
self.build_standard_transformers(sess, gin)
233+
227234

228235
def test_build_from_checkpoint(self):
229236
""" Build TFTransformer from a model checkpoint """
@@ -285,6 +292,10 @@ def test_build_from_checkpoint(self):
285292
gin = TFInputGraph.fromCheckpoint(model_ckpt_dir, self.feed_names, self.fetch_names)
286293
self.build_standard_transformers(sess, gin)
287294

295+
gin = TFInputGraph.fromGraph(
296+
sess.graph, sess, self.feed_names, self.fetch_names)
297+
self.build_standard_transformers(sess, gin)
298+
288299

289300
def test_multi_io(self):
290301
""" Build TFTransformer with multiple I/O tensors """
@@ -300,41 +311,31 @@ def test_multi_io(self):
300311
z = tf.reduce_mean(xs[i], axis=1, name=tnsr_op_name)
301312
zs.append(z)
302313

303-
gin = TFInputGraph.fromGraph(
304-
sess.graph, sess, self.feed_names, self.fetch_names)
305-
self.build_standard_transformers(sess, gin)
306-
307-
gin = TFInputGraph.fromGraphDef(
308-
sess.graph.as_graph_def(), self.feed_names, self.fetch_names)
309-
self.build_standard_transformers(sess, gin)
310-
314+
self._build_default_session_tests(sess)
315+
316+
317+
def test_mixed_keras_graph(self):
318+
""" Build mixed keras graph """
319+
with IsolatedSession(using_keras=True) as issn:
320+
tnsr_in = tf.placeholder(
321+
tf.double, shape=[None, self.vec_size], name=self.input_op_name)
322+
inp = tf.expand_dims(tnsr_in, axis=2)
323+
# Keras layers does not take tf.double
324+
inp = tf.cast(inp, tf.float32)
325+
conv = Conv1D(filters=4, kernel_size=2)(inp)
326+
pool = MaxPool1D(pool_size=2)(conv)
327+
flat = Flatten()(pool)
328+
dense = Dense(1)(flat)
329+
# We must keep the leading dimension of the output
330+
redsum = tf.reduce_sum(dense, axis=1)
331+
tnsr_out = tf.cast(redsum, tf.double, name=self.output_op_name)
332+
333+
# Initialize the variables
334+
init_op = tf.global_variables_initializer()
335+
issn.run(init_op)
336+
# We could train the model ... but skip it here
337+
gfn = issn.asGraphFunction([tnsr_in], [tnsr_out])
311338

312-
# def test_mixed_keras_graph(self):
313-
# # Build the graph: the output should have the same leading/batch dimension
314-
# with IsolatedSession(using_keras=True) as issn:
315-
# tnsr_in = tf.placeholder(
316-
# tf.double, shape=[None, self.vec_size], name=self.input_op_name)
317-
# inp = tf.expand_dims(tnsr_in, axis=2)
318-
# # Keras layers does not take tf.double
319-
# inp = tf.cast(inp, tf.float32)
320-
# conv = Conv1D(filters=4, kernel_size=2)(inp)
321-
# pool = MaxPool1D(pool_size=2)(conv)
322-
# flat = Flatten()(pool)
323-
# dense = Dense(1)(flat)
324-
# # We must keep the leading dimension of the output
325-
# redsum = tf.reduce_sum(dense, axis=1)
326-
# tnsr_out = tf.cast(redsum, tf.double, name=self.output_op_name)
327-
328-
# # Initialize the variables
329-
# init_op = tf.global_variables_initializer()
330-
# issn.run(init_op)
331-
# # We could train the model ... but skip it here
332-
# gfn = issn.asGraphFunction([tnsr_in], [tnsr_out])
333-
334-
# with self.run_test_in_tf_session() as sess:
335-
# tf.import_graph_def(gfn.graph_def, name='')
336-
337-
# self.build_standard_transformers(sess, sess.graph)
338-
# self.build_standard_transformers(sess, TFInputGraphBuilder.fromGraph(sess.graph))
339-
# self.build_standard_transformers(sess, gfn.graph_def)
340-
# self.build_standard_transformers(sess, TFInputGraphBuilder.fromGraphDef(gfn.graph_def))
339+
with self._run_test_in_tf_session() as sess:
340+
tf.import_graph_def(gfn.graph_def, name='')
341+
self._build_default_session_tests(sess)

0 commit comments

Comments
 (0)