1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414#
15-
1615"""
1716Some parts are copied from pyspark.ml.param.shared and some are complementary
1817to pyspark.ml.param. The copy is due to some useful pyspark fns/classes being
1918private APIs.
2019"""
21-
20+ import textwrap
2221from functools import wraps
2322
2423from pyspark .ml .param import Param , Params , TypeConverters
2524
2625from sparkdl .param .converters import SparkDLTypeConverters
2726
28-
2927########################################################
3028# Copied from PySpark for backward compatibility.
3129# They first appeared in Apache Spark version 2.1.1.
3230########################################################
3331
32+
3433def keyword_only (func ):
3534 """
3635 A decorator that forces keyword arguments in the wrapped method
@@ -54,8 +53,8 @@ class HasInputCol(Params):
5453 Mixin for param inputCol: input column name.
5554 """
5655
57- inputCol = Param (
58- Params . _dummy (), "inputCol" , "input column name." , typeConverter = TypeConverters .toString )
56+ inputCol = Param (Params . _dummy (), "inputCol" , "input column name." ,
57+ typeConverter = TypeConverters .toString )
5958
6059 def setInputCol (self , value ):
6160 """
@@ -75,8 +74,8 @@ class HasOutputCol(Params):
7574 Mixin for param outputCol: output column name.
7675 """
7776
78- outputCol = Param (
79- Params . _dummy (), "outputCol" , "output column name." , typeConverter = TypeConverters .toString )
77+ outputCol = Param (Params . _dummy (), "outputCol" , "output column name." ,
78+ typeConverter = TypeConverters .toString )
8079
8180 def __init__ (self ):
8281 super (HasOutputCol , self ).__init__ ()
@@ -94,6 +93,7 @@ def getOutputCol(self):
9493 """
9594 return self .getOrDefault (self .outputCol )
9695
96+
9797########################################################
9898# New in sparkdl
9999########################################################
@@ -196,8 +196,7 @@ class HasOutputMapping(Params):
196196 """
197197 Mixin for param outputMapping: ordered list of ('outputTensorOpName', 'outputColName') pairs
198198 """
199- outputMapping = Param (Params ._dummy (),
200- "outputMapping" ,
199+ outputMapping = Param (Params ._dummy (), "outputMapping" ,
201200 "Mapping output :class:`tf.Operation` names to DataFrame column names" ,
202201 typeConverter = SparkDLTypeConverters .asTensorNameToColumnMap )
203202
@@ -212,8 +211,7 @@ class HasInputMapping(Params):
212211 """
213212 Mixin for param inputMapping: ordered list of ('inputColName', 'inputTensorOpName') pairs
214213 """
215- inputMapping = Param (Params ._dummy (),
216- "inputMapping" ,
214+ inputMapping = Param (Params ._dummy (), "inputMapping" ,
217215 "Mapping input DataFrame column names to :class:`tf.Operation` names" ,
218216 typeConverter = SparkDLTypeConverters .asColumnToTensorNameMap )
219217
@@ -228,9 +226,15 @@ class HasTFHParams(Params):
228226 """
229227 Mixin for TensorFlow model hyper-parameters
230228 """
231- tfHParams = Param (Params ._dummy (),
232- "hparams" ,
233- "instance of :class:`tf.contrib.training.HParams`, a key-value map-like object" ,
229+ tfHParams = Param (Params ._dummy (), "hparams" ,
230+ textwrap .dedent ("""\
231+ instance of :class:`tf.contrib.training.HParams`, a namespace-like
232+ key-value object, storing parameters to be used to define the final
233+ TensorFlow graph for the Transformer.
234+
235+ Currently accepted values are:
236+ - `batch_size`: number of samples provided to the inference graph
237+ during each evaluation function call.""" ,
234238 typeConverter = SparkDLTypeConverters .toTFHParams )
235239
236240 def setTFHParams (self , value ):
0 commit comments