2727
2828from pyspark .sql .types import Row
2929
30+ from sparkdl .graph .builder import IsolatedSession
3031from sparkdl .graph .input import *
3132import sparkdl .graph .utils as tfx
3233from sparkdl .transformers .tf_tensor import TFTransformer
@@ -60,6 +61,15 @@ def setUp(self):
6061 def tearDown (self ):
6162 shutil .rmtree (self .model_output_root , ignore_errors = True )
6263
64+ def _build_default_session_tests (self , sess ):
65+ gin = TFInputGraph .fromGraph (
66+ sess .graph , sess , self .feed_names , self .fetch_names )
67+ self .build_standard_transformers (sess , gin )
68+
69+ gin = TFInputGraph .fromGraphDef (
70+ sess .graph .as_graph_def (), self .feed_names , self .fetch_names )
71+ self .build_standard_transformers (sess , gin )
72+
6373 def build_standard_transformers (self , sess , tf_input_graph ):
6474 def _add_transformer (imap , omap ):
6575 trnsfmr = TFTransformer (
@@ -113,7 +123,7 @@ def _run_test_in_tf_session(self):
113123
114124 # Build the TensorFlow graph
115125 graph = tf .Graph ()
116- with tf .Session (graph = graph ) as sess :
126+ with tf .Session (graph = graph ) as sess , graph . as_default () :
117127 # Build test graph and transformers from here
118128 yield sess
119129
@@ -147,8 +157,7 @@ def _run_test_in_tf_session(self):
147157 _results .append (np .ravel (curr_res ))
148158 out_tgt = np .hstack (_results )
149159
150- self .assertTrue (np .allclose (out_ref , out_tgt ),
151- msg = repr (transfomer ))
160+ self .assertTrue (np .allclose (out_ref , out_tgt ), msg = repr (transfomer ))
152161
153162
154163 def test_build_from_tf_graph (self ):
@@ -159,13 +168,7 @@ def test_build_from_tf_graph(self):
159168 _ = tf .reduce_mean (x , axis = 1 , name = self .output_op_name )
160169 # End building graph
161170
162- # Begin building transformers
163- self .build_standard_transformers (
164- sess , TFInputGraph .fromGraph (sess .graph , sess , self .feed_names , self .fetch_names ))
165- gdef = sess .graph .as_graph_def ()
166- self .build_standard_transformers (
167- sess , TFInputGraph .fromGraphDef (gdef , self .feed_names , self .fetch_names ))
168- # End building transformers
171+ self ._build_default_session_tests (sess )
169172
170173
171174 def test_build_from_saved_model (self ):
@@ -224,6 +227,10 @@ def test_build_from_saved_model(self):
224227 saved_model_dir , serving_tag , self .feed_names , self .fetch_names )
225228 self .build_standard_transformers (sess , gin )
226229
230+ gin = TFInputGraph .fromGraph (
231+ sess .graph , sess , self .feed_names , self .fetch_names )
232+ self .build_standard_transformers (sess , gin )
233+
227234
228235 def test_build_from_checkpoint (self ):
229236 """ Build TFTransformer from a model checkpoint """
@@ -285,6 +292,10 @@ def test_build_from_checkpoint(self):
285292 gin = TFInputGraph .fromCheckpoint (model_ckpt_dir , self .feed_names , self .fetch_names )
286293 self .build_standard_transformers (sess , gin )
287294
295+ gin = TFInputGraph .fromGraph (
296+ sess .graph , sess , self .feed_names , self .fetch_names )
297+ self .build_standard_transformers (sess , gin )
298+
288299
289300 def test_multi_io (self ):
290301 """ Build TFTransformer with multiple I/O tensors """
@@ -300,41 +311,31 @@ def test_multi_io(self):
300311 z = tf .reduce_mean (xs [i ], axis = 1 , name = tnsr_op_name )
301312 zs .append (z )
302313
303- gin = TFInputGraph .fromGraph (
304- sess .graph , sess , self .feed_names , self .fetch_names )
305- self .build_standard_transformers (sess , gin )
306-
307- gin = TFInputGraph .fromGraphDef (
308- sess .graph .as_graph_def (), self .feed_names , self .fetch_names )
309- self .build_standard_transformers (sess , gin )
310-
314+ self ._build_default_session_tests (sess )
315+
316+
317+ def test_mixed_keras_graph (self ):
318+ """ Build mixed keras graph """
319+ with IsolatedSession (using_keras = True ) as issn :
320+ tnsr_in = tf .placeholder (
321+ tf .double , shape = [None , self .vec_size ], name = self .input_op_name )
322+ inp = tf .expand_dims (tnsr_in , axis = 2 )
323+ # Keras layers does not take tf.double
324+ inp = tf .cast (inp , tf .float32 )
325+ conv = Conv1D (filters = 4 , kernel_size = 2 )(inp )
326+ pool = MaxPool1D (pool_size = 2 )(conv )
327+ flat = Flatten ()(pool )
328+ dense = Dense (1 )(flat )
329+ # We must keep the leading dimension of the output
330+ redsum = tf .reduce_sum (dense , axis = 1 )
331+ tnsr_out = tf .cast (redsum , tf .double , name = self .output_op_name )
332+
333+ # Initialize the variables
334+ init_op = tf .global_variables_initializer ()
335+ issn .run (init_op )
336+ # We could train the model ... but skip it here
337+ gfn = issn .asGraphFunction ([tnsr_in ], [tnsr_out ])
311338
312- # def test_mixed_keras_graph(self):
313- # # Build the graph: the output should have the same leading/batch dimension
314- # with IsolatedSession(using_keras=True) as issn:
315- # tnsr_in = tf.placeholder(
316- # tf.double, shape=[None, self.vec_size], name=self.input_op_name)
317- # inp = tf.expand_dims(tnsr_in, axis=2)
318- # # Keras layers does not take tf.double
319- # inp = tf.cast(inp, tf.float32)
320- # conv = Conv1D(filters=4, kernel_size=2)(inp)
321- # pool = MaxPool1D(pool_size=2)(conv)
322- # flat = Flatten()(pool)
323- # dense = Dense(1)(flat)
324- # # We must keep the leading dimension of the output
325- # redsum = tf.reduce_sum(dense, axis=1)
326- # tnsr_out = tf.cast(redsum, tf.double, name=self.output_op_name)
327-
328- # # Initialize the variables
329- # init_op = tf.global_variables_initializer()
330- # issn.run(init_op)
331- # # We could train the model ... but skip it here
332- # gfn = issn.asGraphFunction([tnsr_in], [tnsr_out])
333-
334- # with self.run_test_in_tf_session() as sess:
335- # tf.import_graph_def(gfn.graph_def, name='')
336-
337- # self.build_standard_transformers(sess, sess.graph)
338- # self.build_standard_transformers(sess, TFInputGraphBuilder.fromGraph(sess.graph))
339- # self.build_standard_transformers(sess, gfn.graph_def)
340- # self.build_standard_transformers(sess, TFInputGraphBuilder.fromGraphDef(gfn.graph_def))
339+ with self ._run_test_in_tf_session () as sess :
340+ tf .import_graph_def (gfn .graph_def , name = '' )
341+ self ._build_default_session_tests (sess )
0 commit comments