2828import sparkdl .utils .jvmapi as JVMAPI
2929import sparkdl .graph .utils as tfx
3030
31+ __all__ = ['TFImageTransformer' ]
32+
33+ IMAGE_INPUT_TENSOR_NAME = tfx .as_tensor_name (utils .IMAGE_INPUT_PLACEHOLDER_NAME )
34+ USER_GRAPH_NAMESPACE = 'given'
35+ NEW_OUTPUT_PREFIX = 'sdl_flattened'
36+
3137class TFImageTransformer (Transformer , HasInputCol , HasOutputCol , HasOutputMode ):
3238 """
3339 Applies the Tensorflow graph to the image column in DataFrame.
@@ -47,9 +53,6 @@ class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode):
4753 since a new session is created inside this transformer.
4854 """
4955
50- USER_GRAPH_NAMESPACE = 'given'
51- NEW_OUTPUT_PREFIX = 'sdl_flattened'
52-
5356 graph = Param (Params ._dummy (), "graph" , "A TensorFlow computation graph" ,
5457 typeConverter = SparkDLTypeConverters .toTFGraph )
5558 inputTensor = Param (Params ._dummy (), "inputTensor" ,
@@ -61,28 +64,28 @@ class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode):
6164
6265 @keyword_only
6366 def __init__ (self , inputCol = None , outputCol = None , graph = None ,
64- inputTensor = utils . IMAGE_INPUT_PLACEHOLDER_NAME , outputTensor = None ,
67+ inputTensor = IMAGE_INPUT_TENSOR_NAME , outputTensor = None ,
6568 outputMode = "vector" ):
6669 """
6770 __init__(self, inputCol=None, outputCol=None, graph=None,
68- inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME , outputTensor=None,
71+ inputTensor=IMAGE_INPUT_TENSOR_NAME , outputTensor=None,
6972 outputMode="vector")
7073 """
7174 super (TFImageTransformer , self ).__init__ ()
72- self ._setDefault (inputTensor = utils .IMAGE_INPUT_PLACEHOLDER_NAME )
73- self ._setDefault (outputMode = "vector" )
7475 kwargs = self ._input_kwargs
7576 self .setParams (** kwargs )
7677
7778 @keyword_only
7879 def setParams (self , inputCol = None , outputCol = None , graph = None ,
79- inputTensor = utils . IMAGE_INPUT_PLACEHOLDER_NAME , outputTensor = None ,
80+ inputTensor = IMAGE_INPUT_TENSOR_NAME , outputTensor = None ,
8081 outputMode = "vector" ):
8182 """
8283 setParams(self, inputCol=None, outputCol=None, graph=None,
83- inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME , outputTensor=None,
84+ inputTensor=IMAGE_INPUT_TENSOR_NAME , outputTensor=None,
8485 outputMode="vector")
8586 """
87+ self ._setDefault (inputTensor = IMAGE_INPUT_TENSOR_NAME )
88+ self ._setDefault (outputMode = "vector" )
8689 kwargs = self ._input_kwargs
8790 return self ._set (** kwargs )
8891
@@ -179,7 +182,7 @@ def _addReshapeLayers(self, tf_graph, dtype="uint8"):
179182 # Add on the original graph
180183 tf .import_graph_def (gdef , input_map = {input_tensor_name : image_reshaped_expanded },
181184 return_elements = [self .getOutputTensor ().name ],
182- name = self . USER_GRAPH_NAMESPACE )
185+ name = USER_GRAPH_NAMESPACE )
183186
184187 # Flatten the output for tensorframes
185188 output_node = g .get_tensor_by_name (self ._getOriginalOutputTensorName ())
@@ -198,10 +201,10 @@ def _stripGraph(self, tf_graph):
198201 return g
199202
200203 def _getOriginalOutputTensorName (self ):
201- return self . USER_GRAPH_NAMESPACE + '/' + self .getOutputTensor ().name
204+ return USER_GRAPH_NAMESPACE + '/' + self .getOutputTensor ().name
202205
203206 def _getFinalOutputTensorName (self ):
204- return self . NEW_OUTPUT_PREFIX + '_' + self .getOutputTensor ().name
207+ return NEW_OUTPUT_PREFIX + '_' + self .getOutputTensor ().name
205208
206209 def _getFinalOutputOpName (self ):
207210 return tfx .as_op_name (self ._getFinalOutputTensorName ())
0 commit comments