Skip to content

Commit de088a9

Browse files
authored
TensorFlow Transformer Part-3 (#10)
* intro: TFInputGraph * tests * Merge branch 'tf-transformer-part1' into tf-transformer-part3 * and so there is no helper classes * and into more pieces * class & docs * update docs * refactoring tfx API * update tfx utils usage * one way to build these tests * tests refactored * test cases in a single class THis will make things easier when we want to extend other base class functions. * shuffle things around Signed-off-by: Philip Yang <philip.yang@databricks.com> * docs mostly * yapf'd * consolidate tempdir creation * (wip) PR comments * more tests * change test generator module name * TFTransformer Part-3 Test Refactor (#14) * profiling * tests * renamed test * removed original tests * removed the profiler utils * fixes indents * imports * added some tests * added test * fix test * one more test * PR comments * TensorFlow Transformer Part-4 (#11) * flat param API impl * support input graph scenarios * (WIP) new interface implementation * docs and cleanup * using tensorflow API instead of our utilities * automatic type conversion * cleanup * PR comments 1. Move `InputGraph` to its module. * (WIP) address comments * (WIP) respond to PR comments * test refactor * (wip) consolidating params * rebase upstream * import params fix * (wip) TFInputGraph impl * (wip) moving to new API * (wip) enable saved_model tests * (wip) enable checkpoint test * (wip) enable multiple tensor tests * enable all tests * optimize graph for inference * allows setting TFInputGraph * utilize test_input_graph for transformer tests * enable all tests Signed-off-by: Philip Yang <philip.yang@databricks.com> * input graph * docs * tensor tests * tensor test update * TFTransformer Part-4 Test Refactor (#15) * adding new tests * remove original test design * cleanup * deleting original testing ideas * PR comments
1 parent 63967b4 commit de088a9

11 files changed

Lines changed: 998 additions & 15 deletions

File tree

python/docs/sparkdl.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ Subpackages
66

77
.. toctree::
88

9+
sparkdl.estimators
910
sparkdl.graph
1011
sparkdl.image
12+
sparkdl.param
1113
sparkdl.transformers
1214
sparkdl.udf
1315
sparkdl.utils

python/sparkdl/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@
1313
# limitations under the License.
1414
#
1515

16+
from .graph.input import TFInputGraph
1617
from .image.imageIO import imageSchema, imageType, readImages
1718
from .transformers.keras_image import KerasImageFileTransformer
1819
from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer
1920
from .transformers.tf_image import TFImageTransformer
21+
from .transformers.tf_tensor import TFTransformer
2022
from .transformers.utils import imageInputPlaceholder
2123

24+
2225
# Public API of the sparkdl package: image helpers, transformers, and the
# TFInputGraph wrapper used by TFTransformer.
__all__ = [
    'imageSchema', 'imageType', 'readImages',
    'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
    'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer',
    'imageInputPlaceholder']

python/sparkdl/graph/builder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False):
4747
self.graph = graph or tf.Graph()
4848
self.sess = tf.Session(graph=self.graph)
4949
if using_keras:
50+
self.using_keras = True
5051
self.keras_prev_sess = K.get_session()
5152
else:
53+
self.using_keras = False
5254
self.keras_prev_sess = None
5355

5456
def __enter__(self):
55-
self.sess.as_default()
5657
self.sess.__enter__()
57-
if self.keras_prev_sess is not None:
58+
if self.using_keras:
5859
K.set_session(self.sess)
5960
return self
6061

6162
def __exit__(self, *args):
62-
if self.keras_prev_sess is not None:
63+
if self.using_keras:
6364
K.set_session(self.keras_prev_sess)
6465
self.sess.__exit__(*args)
6566

python/sparkdl/graph/input.py

Lines changed: 355 additions & 0 deletions
Large diffs are not rendered by default.

python/sparkdl/param/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sparkdl.param.shared_params import (
1717
keyword_only, HasInputCol, HasOutputCol, HasLabelCol,
1818
# TFTransformer Params
19-
HasInputMapping, HasOutputMapping, HasTFHParams,
19+
HasInputMapping, HasOutputMapping, HasTFInputGraph, HasTFHParams,
2020
# Keras Estimator Params
2121
HasKerasModel, HasKerasLoss, HasKerasOptimizer, HasOutputNodeName)
2222
from sparkdl.param.converters import SparkDLTypeConverters

python/sparkdl/param/converters.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from pyspark.ml.param import TypeConverters
3232

33+
from sparkdl.graph.input import *
3334
import sparkdl.utils.keras_model as kmutil
3435

3536
__all__ = ['SparkDLTypeConverters']
@@ -52,6 +53,13 @@ def toTFGraph(value):
5253
raise TypeError("Could not convert %s to tf.Graph" % type(value))
5354
return value
5455

56+
@staticmethod
57+
def toTFInputGraph(value):
58+
if isinstance(value, TFInputGraph):
59+
return value
60+
else:
61+
raise TypeError("Could not convert %s to TFInputGraph" % type(value))
62+
5563
@staticmethod
5664
def asColumnToTensorNameMap(value):
5765
"""
@@ -167,7 +175,14 @@ def _check_is_tensor_name(_maybe_tnsr_name):
167175
raise TypeError(err_msg.format(type(_maybe_tnsr_name)))
168176

169177
# The check is taken from TensorFlow's NodeDef protocol buffer.
170-
# https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
178+
# Each input is "node:src_output" with "node" being a string name and
179+
# "src_output" indicating which output tensor to use from "node". If
180+
# "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
181+
# may optionally be followed by control inputs that have the format
182+
# "^node".
183+
# Reference:
184+
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/node_def.proto
185+
# https://stackoverflow.com/questions/36150834/how-does-tensorflow-name-tensors
171186
try:
172187
_, src_idx = _maybe_tnsr_name.split(":")
173188
_ = int(src_idx)

python/sparkdl/param/shared_params.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
"""
2020
import textwrap
2121
from functools import wraps
22+
import six
2223

2324
from pyspark.ml.param import Param, Params, TypeConverters
2425

26+
from sparkdl.graph.input import TFInputGraph
2527
from sparkdl.param.converters import SparkDLTypeConverters
2628

2729
########################################################
@@ -196,8 +198,9 @@ class HasOutputMapping(Params):
196198
"""
197199
Mixin for param outputMapping: ordered list of ('outputTensorOpName', 'outputColName') pairs
198200
"""
199-
outputMapping = Param(Params._dummy(), "outputMapping",
200-
"Mapping output :class:`tf.Operation` names to DataFrame column names",
201+
outputMapping = Param(Params._dummy(),
202+
"outputMapping",
203+
"Mapping output :class:`tf.Tensor` names to DataFrame column names",
201204
typeConverter=SparkDLTypeConverters.asTensorNameToColumnMap)
202205

203206
def setOutputMapping(self, value):
@@ -211,8 +214,9 @@ class HasInputMapping(Params):
211214
"""
212215
Mixin for param inputMapping: ordered list of ('inputColName', 'inputTensorOpName') pairs
213216
"""
214-
inputMapping = Param(Params._dummy(), "inputMapping",
215-
"Mapping input DataFrame column names to :class:`tf.Operation` names",
217+
inputMapping = Param(Params._dummy(),
218+
"inputMapping",
219+
"Mapping input DataFrame column names to :class:`tf.Tensor` names",
216220
typeConverter=SparkDLTypeConverters.asColumnToTensorNameMap)
217221

218222
def setInputMapping(self, value):
@@ -222,6 +226,26 @@ def getInputMapping(self):
222226
return self.getOrDefault(self.inputMapping)
223227

224228

229+
class HasTFInputGraph(Params):
    """
    Mixin for param tfInputGraph: a serializable object derived from a TensorFlow computation graph.
    """
    # Value is validated (not coerced) by SparkDLTypeConverters.toTFInputGraph.
    tfInputGraph = Param(Params._dummy(),
                         "tfInputGraph",
                         "A serializable object derived from a TensorFlow computation graph",
                         typeConverter=SparkDLTypeConverters.toTFInputGraph)

    def __init__(self):
        super(HasTFInputGraph, self).__init__()
        # Default to None so the param is considered "set" but empty until
        # the user provides a real TFInputGraph.
        self._setDefault(tfInputGraph=None)

    def setTFInputGraph(self, value):
        """Set the tfInputGraph param; returns self for chaining (Spark ML convention)."""
        return self._set(tfInputGraph=value)

    def getTFInputGraph(self):
        """Return the current tfInputGraph value (or the None default)."""
        return self.getOrDefault(self.tfInputGraph)
225249
class HasTFHParams(Params):
226250
"""
227251
Mixin for TensorFlow model hyper-parameters
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2017 Databricks, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
from __future__ import absolute_import, division, print_function
16+
17+
import logging
18+
import tensorflow as tf
19+
from tensorflow.python.tools import optimize_for_inference_lib as infr_opt
20+
import tensorframes as tfs
21+
22+
from pyspark.ml import Transformer
23+
24+
import sparkdl.graph.utils as tfx
25+
from sparkdl.param import (keyword_only, HasInputMapping, HasOutputMapping,
26+
HasTFInputGraph, HasTFHParams)
27+
28+
__all__ = ['TFTransformer']
29+
30+
logger = logging.getLogger('sparkdl')
31+
32+
class TFTransformer(Transformer, HasTFInputGraph, HasTFHParams, HasInputMapping, HasOutputMapping):
    """
    Applies the TensorFlow graph to the array column in DataFrame.

    Restrictions of the current API:

    We assume that
    - All the inputs of the graphs have a "minibatch" dimension (i.e. an unknown leading
      dimension) in the tensor shapes.
    - Input DataFrame has an array column where all elements have the same length
    - The transformer is expected to work on blocks of data at the same time.
    """

    @keyword_only
    def __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None):
        """
        __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None)

        :param tfInputGraph: a :class:`TFInputGraph` wrapping the TF computation graph.
        :param inputMapping: ordered pairs mapping DataFrame column names to input tensor names.
        :param outputMapping: ordered pairs mapping output tensor names to DataFrame column names.
        :param tfHParms: TensorFlow hyper-parameters (see HasTFHParams).
        """
        super(TFTransformer, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None):
        """
        setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None)
        """
        super(TFTransformer, self).__init__()
        kwargs = self._input_kwargs
        # Further canonicalization, e.g. converting dict to sorted str pairs, happens here
        # (the Param typeConverters do the actual normalization on _set).
        return self._set(**kwargs)

    def _optimize_for_inference(self):
        """Optimize the graph for inference.

        Strips training-only nodes from the stored graph so that only the ops
        reachable between the mapped input and output tensors remain.

        :return: an optimized :class:`tf.GraphDef` protocol buffer.
        """
        gin = self.getTFInputGraph()
        input_mapping = self.getInputMapping()
        output_mapping = self.getOutputMapping()
        # optimize_for_inference wants op (node) names, not "op:idx" tensor names.
        input_node_names = [tfx.op_name(tnsr_name) for _, tnsr_name in input_mapping]
        output_node_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping]

        # NOTE(phi-dbq): Spark DataFrame assumes float64 as default floating point type
        opt_gdef = infr_opt.optimize_for_inference(gin.graph_def,
                                                   input_node_names,
                                                   output_node_names,
                                                   # TODO: below is the place to change for
                                                   # the `float64` data type issue.
                                                   tf.float64.as_datatype_enum)
        return opt_gdef

    def _transform(self, dataset):
        """Run the (inference-optimized) TF graph over blocks of the dataset.

        :param dataset: input :class:`pyspark.sql.DataFrame`.
        :return: DataFrame with one added column per entry in the output mapping.
        """
        graph_def = self._optimize_for_inference()
        input_mapping = self.getInputMapping()
        output_mapping = self.getOutputMapping()

        graph = tf.Graph()
        with tf.Session(graph=graph):
            # tensorframes needs column shape/type metadata before mapping.
            analyzed_df = tfs.analyze(dataset)

            out_tnsr_op_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping]
            # name='' keeps the original tensor names; return_elements forces the
            # requested output ops to exist in the imported graph.
            tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names)

            # tensorframes feeds placeholders by op name -> DataFrame column name.
            feed_dict = dict((tfx.op_name(tnsr_name, graph), col_name)
                             for col_name, tnsr_name in input_mapping)
            fetches = [tfx.get_tensor(tnsr_op_name, graph) for tnsr_op_name in out_tnsr_op_names]

            out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict)

            # We still have to rename output columns: map_blocks names them after
            # the producing op, while the user asked for output_mapping's names.
            for tnsr_name, new_colname in output_mapping:
                old_colname = tfx.op_name(tnsr_name, graph)
                if old_colname != new_colname:
                    out_df = out_df.withColumnRenamed(old_colname, new_colname)

        return out_df

0 commit comments

Comments
 (0)