Skip to content

Commit ad481a7

Browse files
committed
docs
1 parent 944e328 commit ad481a7

2 files changed

Lines changed: 15 additions & 2 deletions

File tree

python/sparkdl/graph/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def as_tensor_name(tfobj_or_name):
104104
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
105105
"""
106106
if isinstance(tfobj_or_name, six.string_types):
107+
# If input is a string, assume it is a name and infer the corresponding tensor name.
108+
# WARNING: this depends on TensorFlow's tensor naming convention
107109
name = tfobj_or_name
108110
name_parts = name.split(":")
109111
assert len(name_parts) <= 2, name_parts
@@ -125,6 +127,8 @@ def as_op_name(tfobj_or_name):
125127
:param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
126128
"""
127129
if isinstance(tfobj_or_name, six.string_types):
130+
# If input is a string, assume it is a name and infer the corresponding operation name.
131+
# WARNING: this depends on TensorFlow's operation naming convention
128132
name = tfobj_or_name
129133
name_parts = name.split(":")
130134
assert len(name_parts) <= 2, name_parts

python/sparkdl/transformers/param.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,11 @@ class HasTFHParams(Params):
222222
# New in sparkdl
223223

224224
class HasOutputMapping(Params):
225+
"""
226+
Mixin for param outputMapping: ordered list of ('outputTensorName', 'outputColName') pairs
227+
"""
225228
outputMapping = Param(Params._dummy(), "outputMapping",
226-
"Name of output tensor in signature def",
229+
"Mapping output :class:`tf.Tensor` objects to DataFrame column names",
227230
typeConverter=SparkDLTypeConverters.asTensorToColumnMap)
228231

229232
def __init__(self):
@@ -237,8 +240,11 @@ def getOutputMapping(self):
237240

238241

239242
class HasInputMapping(Params):
243+
"""
244+
Mixin for param inputMapping: ordered list of ('inputColName', 'inputTensorName') pairs
245+
"""
240246
inputMapping = Param(Params._dummy(), "inputMapping",
241-
"Name of input tensor in signature def",
247+
"Mapping input DataFrame column names to :class:`tf.Tensor` objects",
242248
typeConverter=SparkDLTypeConverters.asColumnToTensorMap)
243249

244250
def __init__(self):
@@ -252,6 +258,9 @@ def getInputMapping(self):
252258

253259

254260
class HasTFGraph(Params):
261+
"""
262+
Mixin for param tfGraph: the :class:`tf.Graph` object that represents a TensorFlow computation.
263+
"""
255264
tfGraph = Param(Params._dummy(), "tfGraph",
256265
"TensorFlow Graph object",
257266
typeConverter=SparkDLTypeConverters.toTFGraph)

0 commit comments

Comments
 (0)