@@ -33,120 +33,131 @@ class SparkDLTypeConverters(object):
3333 These methods are similar to :py:class:`spark.ml.param.TypeConverters`.
3434 They provide support for the `Params` types introduced in Spark Deep Learning Pipelines.
3535 """
36+
3637 @staticmethod
3738 def toTFGraph (value ):
38- if isinstance (value , tf .Graph ):
39- return value
40- else :
41- raise TypeError ("Could not convert %s to TensorFlow Graph" % type (value ))
39+ if not isinstance (value , tf .Graph ):
40+ raise TypeError ("Could not convert %s to tf.Graph" % type (value ))
41+ return value
4242
4343 @staticmethod
4444 def asColumnToTensorNameMap (value ):
4545 """
4646 Convert a value to a column name to :py:obj:`tf.Tensor` name mapping
4747 as a sorted list of string pairs, if possible.
4848 """
49- if isinstance (value , dict ):
50- strs_pair_seq = []
51- for _maybe_col_name , _maybe_tnsr_name in value .items ():
52- # Check if the non-tensor value is of string type
53- _col_name = _get_strict_col_name (_maybe_col_name )
54- # Check if the tensor name is actually valid
55- _tnsr_name = _get_strict_tensor_name (_maybe_tnsr_name )
56- strs_pair_seq .append ((_col_name , _tnsr_name ))
49+ if not isinstance (value , dict ):
50+ err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
51+ raise TypeError (err_msg .format (type (value ), value ))
5752
58- return sorted (strs_pair_seq )
53+ # Conversion logic after quick type check
54+ strs_pair_seq = []
55+ for _maybe_col_name , _maybe_tnsr_name in value .items ():
56+ # Check if the non-tensor value is of string type
57+ _check_is_str (_maybe_col_name )
58+ # Check if the tensor name looks like a tensor name
59+ _check_is_tensor_name (_maybe_tnsr_name )
60+ strs_pair_seq .append ((_maybe_col_name , _maybe_tnsr_name ))
5961
60- err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
61- raise TypeError (err_msg .format (type (value ), value ))
62+ return sorted (strs_pair_seq )
6263
6364 @staticmethod
6465 def asTensorNameToColumnMap (value ):
6566 """
6667 Convert a value to a :py:obj:`tf.Tensor` name to column name mapping
6768 as a sorted list of string pairs, if possible.
6869 """
69- if isinstance (value , dict ):
70- strs_pair_seq = []
71- for _maybe_tnsr_name , _maybe_col_name in value .items ():
72- # Check if the non-tensor value is of string type
73- _col_name = _get_strict_col_name (_maybe_col_name )
74- # Check if the tensor name is actually valid
75- _tnsr_name = _get_strict_tensor_name (_maybe_tnsr_name )
76- strs_pair_seq .append ((_tnsr_name , _col_name ))
70+ if not isinstance (value , dict ):
71+ err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
72+ raise TypeError (err_msg .format (type (value ), value ))
7773
78- return sorted (strs_pair_seq )
74+ # Conversion logic after quick type check
75+ strs_pair_seq = []
76+ for _maybe_tnsr_name , _maybe_col_name in value .items ():
77+ # Check if the non-tensor value is of string type
78+ _check_is_str (_maybe_col_name )
79+ # Check if the tensor name looks like a tensor name
80+ _check_is_tensor_name (_maybe_tnsr_name )
81+ strs_pair_seq .append ((_maybe_tnsr_name , _maybe_col_name ))
7982
80- err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
81- raise TypeError (err_msg .format (type (value ), value ))
83+ return sorted (strs_pair_seq )
8284
8385 @staticmethod
8486 def toTFHParams (value ):
8587 """ Convert a value to a :py:obj:`tf.contrib.training.HParams` object, if possible. """
86- if isinstance (value , tf .contrib .training .HParams ):
87- return value
88- else :
88+ if not isinstance (value , tf .contrib .training .HParams ):
8989 raise TypeError ("Could not convert %s to TensorFlow HParams" % type (value ))
9090
91+ return value
92+
9193 @staticmethod
92- def toStringOrTFTensor (value ):
94+ def toTFTensorName (value ):
9395 """ Convert a value to a str or a :py:obj:`tf.Tensor` object, if possible. """
9496 if isinstance (value , tf .Tensor ):
95- return value
97+ return value . name
9698 try :
99+ _check_is_tensor_name (value )
97100 return TypeConverters .toString (value )
98101 except Exception as exc :
99102 err_msg = "Could not convert [type {}] {} to tf.Tensor or str. {}"
100103 raise TypeError (err_msg .format (type (value ), value , exc ))
101104
102105 @staticmethod
103- def supportedNameConverter (supportedList ):
106+ def buildCheckList (supportedList ):
104107 """
105108 Create a converter that tries to check if a value is part of the supported list.
106109
107110 :param supportedList: list, containing supported objects.
108111 :return: a converter that tries to convert a value if it is part of the `supportedList`.
109112 """
113+
110114 def converter (value ):
111- if value in supportedList :
112- return value
113- err_msg = "[type {}] {} is not in the supported list: {}"
114- raise TypeError (err_msg .format (type (value ), str (value ), supportedList ))
115+ if value not in supportedList :
116+ err_msg = "[type {}] {} is not in the supported list: {}"
117+ raise TypeError (err_msg .format (type (value ), str (value ), supportedList ))
118+
119+ return value
115120
116121 return converter
117122
118123 @staticmethod
119124 def toKerasLoss (value ):
120125 """ Convert a value to a name of Keras loss function, if possible """
121- if kmutil .is_valid_loss_function (value ):
122- return value
123- err_msg = "Named loss not supported in Keras: [type {}] {}"
124- raise ValueError (err_msg .format (type (value ), value ))
126+ # Return early for clarity as well as less indentation
127+ if not kmutil .is_valid_loss_function (value ):
128+ err_msg = "Named loss not supported in Keras: [type {}] {}"
129+ raise ValueError (err_msg .format (type (value ), value ))
130+
131+ return value
125132
126133 @staticmethod
127134 def toKerasOptimizer (value ):
128135 """ Convert a value to a name of Keras optimizer, if possible """
129- if kmutil .is_valid_optimizer (value ):
130- return value
131- err_msg = "Named optimizer not supported in Keras: [type {}] {}"
132- raise TypeError (err_msg .format (type (value ), value ))
136+ if not kmutil .is_valid_optimizer (value ):
137+ err_msg = "Named optimizer not supported in Keras: [type {}] {}"
138+ raise TypeError (err_msg .format (type (value ), value ))
139+
140+ return value
133141
134142
135- def _get_strict_tensor_name (_maybe_tnsr_name ):
143+ def _check_is_tensor_name (_maybe_tnsr_name ):
136144 """ Check if the input is a valid tensor name """
137145 try :
138146 assert isinstance (_maybe_tnsr_name , six .string_types ), \
139147 "must provide a strict tensor name as input, but got {}" .format (type (_maybe_tnsr_name ))
140- assert tfx .as_tensor_name (_maybe_tnsr_name ) == _maybe_tnsr_name , \
141- "input {} must be a valid tensor name" .format (_maybe_tnsr_name )
148+
149+ # The check is taken from TensorFlow's NodeDef protocol buffer.
150+ # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/node_def.proto#L21-L25
151+ _ , src_idx = _maybe_tnsr_name .split (":" )
152+ _ = int (src_idx )
142153 except Exception as exc :
143154 err_msg = "Can NOT convert [type {}] {} to tf.Tensor name: {}"
144155 raise TypeError (err_msg .format (type (_maybe_tnsr_name ), _maybe_tnsr_name , exc ))
145156 else :
146157 return _maybe_tnsr_name
147158
148159
149- def _get_strict_col_name (_maybe_col_name ):
160+ def _check_is_str (_maybe_col_name ):
150161 """ Check if the given column name is a valid column name """
151162 # We only check if the column name candidate is a string type
152163 if not isinstance (_maybe_col_name , six .string_types ):
0 commit comments