Skip to content

Commit cf856db

Browse files
authored
TensorFlow Transformer Part-2 (#9)
* update utils * tests * fix style Using the following YAPF style ======================================================== based_on_style = pep8 ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT=True BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=False COLUMN_LIMIT=100 SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET=False SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED=True SPLIT_BEFORE_FIRST_ARGUMENT=False SPLIT_BEFORE_NAMED_ASSIGNS=False SPLIT_PENALTY_AFTER_OPENING_BRACKET=30 USE_TABS=False ======================================================== * refactoring tfx API * test refactoring * PR comments 1. docs in graph/utils.py * (wip) utils test * a few more tests for utils * test update cont'd * PR comments * PR comments * PR comments * TensorFlow Transformer Part-3 (#10) * intro: TFInputGraph * tests * Merge branch 'tf-transformer-part1' into tf-transformer-part3 * and so there is no helper classes * and into more pieces * class & docs * update docs * refactoring tfx API * update tfx utils usage * one way to build these tests * tests refactored * test cases in a single class THis will make things easier when we want to extend other base class functions. * shuffle things around Signed-off-by: Philip Yang <philip.yang@databricks.com> * docs mostly * yapf'd * consolidate tempdir creation * (wip) PR comments * more tests * change test generator module name * TFTransformer Part-3 Test Refactor (#14) * profiling * tests * renamed test * removed original tests * removed the profiler utils * fixes indents * imports * added some tests * added test * fix test * one more test * PR comments * TensorFlow Transformer Part-4 (#11) * flat param API impl * support input graph scenarios * (WIP) new interface implementation * docs and cleanup * using tensorflow API instead of our utilities * automatic type conversion * cleanup * PR comments 1. Move `InputGraph` to its module. * (WIP) address comments * (WIP) respond to PR comments * test refactor * (wip) consolidating params * rebase upstream * import params fix * (wip) TFInputGraph impl * (wip) moving to new API * (wip) enable saved_model tests * (wip) enable checkpoint test * (wip) enable multiple tensor tests * enable all tests * optimize graph for inference * allows setting TFInputGraph * utilize test_input_graph for transformer tests * enable all tests Signed-off-by: Philip Yang <philip.yang@databricks.com> * input graph * docs * tensor tests * tensor test update * TFTransformer Part-4 Test Refactor (#15) * adding new tests * remove original test design * cleanup * deleting original testing ideas * PR comments
1 parent a8531ec commit cf856db

19 files changed

Lines changed: 1297 additions & 119 deletions

python/docs/sparkdl.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ Subpackages
66

77
.. toctree::
88

9+
sparkdl.estimators
910
sparkdl.graph
1011
sparkdl.image
12+
sparkdl.param
1113
sparkdl.transformers
1214
sparkdl.udf
1315
sparkdl.utils

python/sparkdl/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@
1313
# limitations under the License.
1414
#
1515

16+
from .graph.input import TFInputGraph
1617
from .image.imageIO import imageSchema, imageType, readImages
1718
from .transformers.keras_image import KerasImageFileTransformer
1819
from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer
1920
from .transformers.tf_image import TFImageTransformer
21+
from .transformers.tf_tensor import TFTransformer
2022
from .transformers.utils import imageInputPlaceholder
2123

24+
2225
__all__ = [
2326
'imageSchema', 'imageType', 'readImages',
24-
'TFImageTransformer',
25-
'DeepImagePredictor', 'DeepImageFeaturizer',
26-
'KerasImageFileTransformer',
27+
'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
28+
'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer',
2729
'imageInputPlaceholder']

python/sparkdl/graph/builder.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False):
4747
self.graph = graph or tf.Graph()
4848
self.sess = tf.Session(graph=self.graph)
4949
if using_keras:
50+
self.using_keras = True
5051
self.keras_prev_sess = K.get_session()
5152
else:
53+
self.using_keras = False
5254
self.keras_prev_sess = None
5355

5456
def __enter__(self):
55-
self.sess.as_default()
5657
self.sess.__enter__()
57-
if self.keras_prev_sess is not None:
58+
if self.using_keras:
5859
K.set_session(self.sess)
5960
return self
6061

6162
def __exit__(self, *args):
62-
if self.keras_prev_sess is not None:
63+
if self.using_keras:
6364
K.set_session(self.keras_prev_sess)
6465
self.sess.__exit__(*args)
6566

@@ -87,8 +88,8 @@ def asGraphFunction(self, inputs, outputs, strip_and_freeze=True):
8788
else:
8889
gdef = self.graph.as_graph_def(add_shapes=True)
8990
return GraphFunction(graph_def=gdef,
90-
input_names=[tfx.validated_input(self.graph, elem) for elem in inputs],
91-
output_names=[tfx.validated_output(self.graph, elem) for elem in outputs])
91+
input_names=[tfx.validated_input(elem, self.graph) for elem in inputs],
92+
output_names=[tfx.validated_output(elem, self.graph) for elem in outputs])
9293

9394
def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs):
9495
"""
@@ -130,8 +131,8 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_k
130131
return_elements=gfn.output_names,
131132
name=scope_name,
132133
**gdef_kargs)
133-
feeds = [tfx.get_tensor(self.graph, name) for name in input_names]
134-
fetches = [tfx.get_tensor(self.graph, name) for name in output_names]
134+
feeds = [tfx.get_tensor(name, self.graph) for name in input_names]
135+
fetches = [tfx.get_tensor(name, self.graph) for name in output_names]
135136
return (feeds, fetches)
136137

137138

@@ -233,7 +234,7 @@ def fromList(cls, functions):
233234
_, first_gfn = functions[0]
234235
feeds, _ = issn.importGraphFunction(first_gfn, prefix='')
235236
for tnsr in feeds:
236-
name = tfx.op_name(issn.graph, tnsr)
237+
name = tfx.op_name(tnsr, issn.graph)
237238
first_input_info.append((tnsr.dtype, tnsr.shape, name))
238239
# TODO: make sure that this graph is not reused to prevent name conflict
239240
# Report error if the graph is not manipulated by anyone else
@@ -268,4 +269,3 @@ def fromList(cls, functions):
268269
gfn = issn.asGraphFunction(first_inputs, last_outputs)
269270

270271
return gfn
271-

0 commit comments

Comments
 (0)