2222
2323from sparkdl .graph .builder import IsolatedSession
2424import sparkdl .graph .utils as tfx
25- from sparkdl .transformers .tf_tensor import TFTensorTransformer
25+ from sparkdl .transformers .tf_tensor import TFModelTransformer
2626
2727from ..tests import SparkDLTestCase
2828
@@ -31,7 +31,7 @@ def grab_df_arr(df, output_col):
3131 return np .array ([row .asDict ()[output_col ]
3232 for row in df .select (output_col ).toLocalIterator ()])
3333
34- class TFOneDimTransformerTest (SparkDLTestCase ):
34+ class TFModelTransformerTest (SparkDLTestCase ):
3535
3636 def _get_rand_vec_df (self , num_rows , vec_size ):
3737 return self .session .createDataFrame (
@@ -60,13 +60,13 @@ def test_simple(self):
6060 out_ref = np .hstack (_results )
6161
6262 # Apply the transform
63- transfomer = TFTensorTransformer (tfGraph = graph ,
64- inputMapping = {
65- 'vec' : x
66- },
67- outputMapping = {
68- z : 'outCol'
69- })
63+ transfomer = TFModelTransformer (tfGraph = graph ,
64+ inputMapping = {
65+ 'vec' : x
66+ },
67+ outputMapping = {
68+ z : 'outCol'
69+ })
7070 final_df = transfomer .transform (analyzed_df )
7171 out_tgt = grab_df_arr (final_df , 'outCol' )
7272
@@ -105,15 +105,15 @@ def test_multi_io(self):
105105 q_out_ref = np .hstack (q_out_ref )
106106
107107 # Apply the transform
108- transfomer = TFTensorTransformer (tfGraph = graph ,
109- inputMapping = {
110- 'vec_x' : x ,
111- 'vec_y' : y
112- },
113- outputMapping = {
114- p : 'out_p' ,
115- q : 'out_q'
116- })
108+ transfomer = TFModelTransformer (tfGraph = graph ,
109+ inputMapping = {
110+ 'vec_x' : x ,
111+ 'vec_y' : y
112+ },
113+ outputMapping = {
114+ p : 'out_p' ,
115+ q : 'out_q'
116+ })
117117 final_df = transfomer .transform (analyzed_df )
118118 p_out_tgt = grab_df_arr (final_df , 'out_p' )
119119 q_out_tgt = grab_df_arr (final_df , 'out_q' )
@@ -169,13 +169,13 @@ def test_map_blocks_graph(self):
169169 arr_ref = grab_df_arr (final_df , output_col )
170170
171171 # Using the Transformer
172- transformer = TFTensorTransformer (tfGraph = gfn ,
173- inputMapping = {
174- input_col : feeds [0 ]
175- },
176- outputMapping = {
177- fetches [0 ]: output_col
178- })
172+ transformer = TFModelTransformer (tfGraph = gfn ,
173+ inputMapping = {
174+ input_col : feeds [0 ]
175+ },
176+ outputMapping = {
177+ fetches [0 ]: output_col
178+ })
179179 transformed_df = transformer .transform (analyzed_df )
180180
181181 arr_tgt = grab_df_arr (transformed_df , output_col )
0 commit comments