@@ -202,11 +202,7 @@ class HasOutputMapping(Params):
202202 typeConverter = SparkDLTypeConverters .asTensorNameToColumnMap )
203203
204204 def setOutputMapping (self , value ):
205- # NOTE(phi-dbq): due to the nature of TensorFlow import modes, we can only derive the
206- # serializable TFInputGraph object once the inputMapping and outputMapping
207- # parameters are provided.
208- raise NotImplementedError (
209- "Please use the Transformer's constructor to assign `outputMapping` field." )
205+ return self ._set (outputMapping = value )
210206
211207 def getOutputMapping (self ):
212208 return self .getOrDefault (self .outputMapping )
@@ -222,11 +218,7 @@ class HasInputMapping(Params):
222218 typeConverter = SparkDLTypeConverters .asColumnToTensorNameMap )
223219
224220 def setInputMapping (self , value ):
225- # NOTE(phi-dbq): due to the nature of TensorFlow import modes, we can only derive the
226- # serializable TFInputGraph object once the inputMapping and outputMapping
227- # parameters are provided.
228- raise NotImplementedError (
229- "Please use the Transformer's constructor to assigne `inputMapping` field." )
221+ return self ._set (inputMapping = value )
230222
231223 def getInputMapping (self ):
232224 return self .getOrDefault (self .inputMapping )
0 commit comments