Skip to content

Commit 1365723

Browse files
committed
address comments
1 parent 146114b commit 1365723

2 files changed

Lines changed: 29 additions & 29 deletions

File tree

python/sparkdl/transformers/tf_tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@
2929
keyword_only, HasInputMapping, HasOutputMapping, SparkDLTypeConverters,
3030
HasTFGraph, HasTFHParams)
3131

32-
__all__ = ['TFTensorTransformer']
32+
__all__ = ['TFModelTransformer']
3333

3434
logger = logging.getLogger('sparkdl')
3535

36-
class TFTensorTransformer(Transformer, HasTFGraph, HasTFHParams, HasInputMapping, HasOutputMapping):
36+
class TFModelTransformer(Transformer, HasTFGraph, HasTFHParams, HasInputMapping, HasOutputMapping):
3737
"""
3838
Applies the TensorFlow graph to the array column in DataFrame.
3939
@@ -50,7 +50,7 @@ def __init__(self, inputMapping=None, outputMapping=None, tfGraph=None, hparams=
5050
"""
5151
__init__(self, inputMapping=None, outputMapping=None, tfGraph=None, hparams=None)
5252
"""
53-
super(TFTensorTransformer, self).__init__()
53+
super(TFModelTransformer, self).__init__()
5454
kwargs = self._input_kwargs
5555
self.setParams(**kwargs)
5656

@@ -59,7 +59,7 @@ def setParams(self, inputMapping=None, outputMapping=None, tfGraph=None, hparams
5959
"""
6060
setParams(self, inputMapping=None, outputMapping=None, tfGraph=None, hparams=None)
6161
"""
62-
super(TFTensorTransformer, self).__init__()
62+
super(TFModelTransformer, self).__init__()
6363
kwargs = self._input_kwargs
6464
return self._set(**kwargs)
6565

python/tests/transformers/tf_tensor_test.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from sparkdl.graph.builder import IsolatedSession
2424
import sparkdl.graph.utils as tfx
25-
from sparkdl.transformers.tf_tensor import TFTensorTransformer
25+
from sparkdl.transformers.tf_tensor import TFModelTransformer
2626

2727
from ..tests import SparkDLTestCase
2828

@@ -31,7 +31,7 @@ def grab_df_arr(df, output_col):
3131
return np.array([row.asDict()[output_col]
3232
for row in df.select(output_col).toLocalIterator()])
3333

34-
class TFOneDimTransformerTest(SparkDLTestCase):
34+
class TFModelTransformerTest(SparkDLTestCase):
3535

3636
def _get_rand_vec_df(self, num_rows, vec_size):
3737
return self.session.createDataFrame(
@@ -60,13 +60,13 @@ def test_simple(self):
6060
out_ref = np.hstack(_results)
6161

6262
# Apply the transform
63-
transfomer = TFTensorTransformer(tfGraph=graph,
64-
inputMapping={
65-
'vec': x
66-
},
67-
outputMapping={
68-
z: 'outCol'
69-
})
63+
transfomer = TFModelTransformer(tfGraph=graph,
64+
inputMapping={
65+
'vec': x
66+
},
67+
outputMapping={
68+
z: 'outCol'
69+
})
7070
final_df = transfomer.transform(analyzed_df)
7171
out_tgt = grab_df_arr(final_df, 'outCol')
7272

@@ -105,15 +105,15 @@ def test_multi_io(self):
105105
q_out_ref = np.hstack(q_out_ref)
106106

107107
# Apply the transform
108-
transfomer = TFTensorTransformer(tfGraph=graph,
109-
inputMapping={
110-
'vec_x': x,
111-
'vec_y': y
112-
},
113-
outputMapping={
114-
p: 'out_p',
115-
q: 'out_q'
116-
})
108+
transfomer = TFModelTransformer(tfGraph=graph,
109+
inputMapping={
110+
'vec_x': x,
111+
'vec_y': y
112+
},
113+
outputMapping={
114+
p: 'out_p',
115+
q: 'out_q'
116+
})
117117
final_df = transfomer.transform(analyzed_df)
118118
p_out_tgt = grab_df_arr(final_df, 'out_p')
119119
q_out_tgt = grab_df_arr(final_df, 'out_q')
@@ -169,13 +169,13 @@ def test_map_blocks_graph(self):
169169
arr_ref = grab_df_arr(final_df, output_col)
170170

171171
# Using the Transformer
172-
transformer = TFTensorTransformer(tfGraph=gfn,
173-
inputMapping={
174-
input_col: feeds[0]
175-
},
176-
outputMapping={
177-
fetches[0]: output_col
178-
})
172+
transformer = TFModelTransformer(tfGraph=gfn,
173+
inputMapping={
174+
input_col: feeds[0]
175+
},
176+
outputMapping={
177+
fetches[0]: output_col
178+
})
179179
transformed_df = transformer.transform(analyzed_df)
180180

181181
arr_tgt = grab_df_arr(transformed_df, output_col)

0 commit comments

Comments
 (0)