2424
2525__all__ = ['SparkDLTypeConverters' ]
2626
def _get_strict_tensor_name(_maybe_tnsr_name):
    """
    Return the input unchanged if it already is a strict tensor name.

    The input must be a string and must come back unchanged from
    `tfx.as_tensor_name`.

    :param _maybe_tnsr_name: candidate tensor name.
    :return: `_maybe_tnsr_name` itself, when it passes both checks.
    :raises AssertionError: if the input is not a string, or if
        `tfx.as_tensor_name` returns something different from the input.
    """
    # Explicit checks instead of `assert` statements: asserts are removed when
    # Python runs with -O, which would silently disable this validation.
    # AssertionError is kept so the exception type seen by callers is unchanged.
    if not isinstance(_maybe_tnsr_name, six.string_types):
        raise AssertionError(
            "must provide a strict tensor name as input, but got {}".format(
                type(_maybe_tnsr_name)))
    if tfx.as_tensor_name(_maybe_tnsr_name) != _maybe_tnsr_name:
        raise AssertionError(
            "input {} must be a valid tensor name".format(_maybe_tnsr_name))
    return _maybe_tnsr_name
33-
def _try_convert_tf_tensor_mapping(value, is_key_tf_tensor=True):
    """
    Convert a dict relating tensor names and plain strings into a sorted
    list of (key, value) pairs.

    :param value: the candidate mapping; anything but a dict is rejected.
    :param is_key_tf_tensor: when True the dict keys are the tensor names and
        the values are plain strings; when False the roles are swapped.
    :return: sorted list of the validated (key, value) pairs.
    :raises TypeError: if `value` is not a dict, if the non-tensor side of an
        entry is not a string, or if the tensor side fails
        `_get_strict_tensor_name`.
    """
    if not isinstance(value, dict):
        if is_key_tf_tensor:
            raise TypeError("Could not convert %s to tf.Tensor name to str mapping" % type(value))
        raise TypeError("Could not convert %s to str to tf.Tensor name mapping" % type(value))

    pairs = []
    for key, val in value.items():
        # The side that is not a tensor name only has to be a string.
        plain_side = val if is_key_tf_tensor else key
        if not isinstance(plain_side, six.string_types):
            raise TypeError(
                'expect string type for {}, but got {}'.format(plain_side, type(plain_side)))

        # The other side has to pass the strict tensor name check.
        try:
            if is_key_tf_tensor:
                entry = (_get_strict_tensor_name(key), val)
            else:
                entry = (key, _get_strict_tensor_name(val))
        except Exception as exc:
            tnsr_side = key if is_key_tf_tensor else val
            raise TypeError(
                "Can NOT convert {} (type {}) to tf.Tensor name: {}".format(
                    tnsr_side, type(tnsr_side), exc))

        pairs.append(entry)

    pairs.sort()
    return pairs
63-
6427
6528class SparkDLTypeConverters (object ):
29+ """
30+ .. note:: DeveloperApi
31+
32+ Factory methods for common type conversion functions for :py:func:`Param.typeConverter`.
33+ These methods are similar to :py:class:`spark.ml.param.TypeConverters`.
34+ They provide support for the `Params` types introduced in Spark Deep Learning Pipelines.
35+ """
6636 @staticmethod
6737 def toTFGraph (value ):
6838 if isinstance (value , tf .Graph ):
@@ -72,49 +42,114 @@ def toTFGraph(value):
7242
7343 @staticmethod
def asColumnToTensorNameMap(value):
    """
    Convert a dict of column name -> :py:obj:`tf.Tensor` name into a sorted
    list of (column name, tensor name) string pairs.

    :param value: dict whose keys are column names and whose values are
        tensor names.
    :return: sorted list of (column name, tensor name) pairs.
    :raises TypeError: if `value` is not a dict, or if a key or value is
        rejected by `_get_strict_col_name` / `_get_strict_tensor_name`.
    """
    if not isinstance(value, dict):
        err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping"
        raise TypeError(err_msg.format(type(value), value))

    # Tuple elements are evaluated left to right, so for every entry the
    # column name is validated before the tensor name.
    return sorted(
        (_get_strict_col_name(col), _get_strict_tensor_name(tnsr))
        for col, tnsr in value.items())
7662
7763 @staticmethod
def asTensorNameToColumnMap(value):
    """
    Convert a dict of :py:obj:`tf.Tensor` name -> column name into a sorted
    list of (tensor name, column name) string pairs.

    :param value: dict whose keys are tensor names and whose values are
        column names.
    :return: sorted list of (tensor name, column name) pairs.
    :raises TypeError: if `value` is not a dict, or if a key or value is
        rejected by `_get_strict_tensor_name` / `_get_strict_col_name`.
    """
    if not isinstance(value, dict):
        err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping"
        raise TypeError(err_msg.format(type(value), value))

    pairs = []
    for tnsr, col in value.items():
        # The column name is validated first, then the tensor name.
        checked_col = _get_strict_col_name(col)
        checked_tnsr = _get_strict_tensor_name(tnsr)
        pairs.append((checked_tnsr, checked_col))
    pairs.sort()
    return pairs
8082
8183 @staticmethod
def toTFHParams(value):
    """
    Return `value` if it is a :py:obj:`tf.contrib.training.HParams` object.

    :param value: the candidate object.
    :return: `value`, unchanged.
    :raises TypeError: if `value` is not an `HParams` instance.
    """
    if not isinstance(value, tf.contrib.training.HParams):
        raise TypeError("Could not convert %s to TensorFlow HParams" % type(value))
    return value
8790
8891 @staticmethod
def toStringOrTFTensor(value):
    """
    Convert a value to a str or a :py:obj:`tf.Tensor` object, if possible.

    :param value: a `tf.Tensor` (returned as is) or anything accepted by
        `TypeConverters.toString`.
    :return: the tensor itself, or the result of `TypeConverters.toString`.
    :raises TypeError: if `TypeConverters.toString` fails for any reason; the
        underlying error is included in the message.
    """
    if not isinstance(value, tf.Tensor):
        try:
            value = TypeConverters.toString(value)
        except Exception as exc:
            # `value` still holds the original input here, since the
            # assignment above did not complete.
            raise TypeError(
                "Could not convert [type {}] {} to tf.Tensor or str. {}".format(
                    type(value), value, exc))
    return value
97101
98102 @staticmethod
def supportedNameConverter(supportedList):
    """
    Create a converter that only accepts values found in the supported list.

    :param supportedList: list, containing supported objects. Membership is
        tested with the `in` operator.
    :return: a converter that returns its argument unchanged when it is part
        of `supportedList` and raises `TypeError` otherwise.
    """
    def converter(value):
        if value in supportedList:
            return value
        # Build the message with str.format. The earlier
        # `"%s %s ..." % type(value), str(value)` form applied `%` to
        # `type(value)` only, so formatting failed and the intended message
        # was never produced.
        err_msg = "[type {}] {} is not in the supported list: {}"
        raise TypeError(err_msg.format(type(value), str(value), supportedList))

    return converter
107117
108118 @staticmethod
def toKerasLoss(value):
    """
    Return `value` if `kmutil.is_valid_loss_function` accepts it.

    :param value: candidate name of a Keras loss function.
    :return: `value`, unchanged.
    :raises ValueError: if `value` is not accepted as a loss function.
    """
    if not kmutil.is_valid_loss_function(value):
        raise ValueError(
            "Named loss not supported in Keras: [type {}] {}".format(type(value), value))
    return value
114125
115126 @staticmethod
def toKerasOptimizer(value):
    """
    Return `value` if `kmutil.is_valid_optimizer` accepts it.

    :param value: candidate name of a Keras optimizer.
    :return: `value`, unchanged.
    :raises TypeError: if `value` is not accepted as an optimizer.
    """
    if not kmutil.is_valid_optimizer(value):
        raise TypeError(
            "Named optimizer not supported in Keras: [type {}] {}".format(type(value), value))
    return value
133+
134+
def _get_strict_tensor_name(_maybe_tnsr_name):
    """
    Check if the input is a valid tensor name and return it unchanged.

    The input must be a string and must come back unchanged from
    `tfx.as_tensor_name`.

    :param _maybe_tnsr_name: candidate tensor name.
    :return: `_maybe_tnsr_name` itself, when it passes both checks.
    :raises TypeError: if the input is not a string, if `tfx.as_tensor_name`
        raises, or if it returns something different from the input.
    """
    def _rejection(reason):
        # Every failure is reported as a TypeError with the same message layout.
        err_msg = "Can NOT convert [type {}] {} to tf.Tensor name: {}"
        return TypeError(err_msg.format(type(_maybe_tnsr_name), _maybe_tnsr_name, reason))

    # Explicit checks instead of `assert` statements: asserts are removed when
    # Python runs with -O, which would let any input through unvalidated.
    if not isinstance(_maybe_tnsr_name, six.string_types):
        raise _rejection(
            "must provide a strict tensor name as input, but got {}".format(
                type(_maybe_tnsr_name)))

    try:
        converted = tfx.as_tensor_name(_maybe_tnsr_name)
    except Exception as exc:
        # Errors raised by the helper are reported as a TypeError too.
        raise _rejection(exc)

    if converted != _maybe_tnsr_name:
        raise _rejection(
            "input {} must be a valid tensor name".format(_maybe_tnsr_name))
    return _maybe_tnsr_name
147+
148+
def _get_strict_col_name(_maybe_col_name):
    """
    Check if the given column name candidate is acceptable and return it.

    The only check performed is that the candidate is a string type.

    :param _maybe_col_name: candidate column name.
    :return: `_maybe_col_name`, unchanged.
    :raises TypeError: if the candidate is not a string.
    """
    if isinstance(_maybe_col_name, six.string_types):
        return _maybe_col_name
    raise TypeError(
        'expect string type but got type {} for {}'.format(
            type(_maybe_col_name), _maybe_col_name))
0 commit comments