@@ -251,7 +251,7 @@ def _validate_elementwise_loss(custom_loss, *, has_weights: bool) -> None:
251251 ok = bool (jl .applicable (custom_loss , 1.0 , 1.0 , 1.0 ))
252252 if not ok :
253253 raise ValueError (
254- "`elementwise_loss` must accept (prediction, target, weights ) when `weights` is passed to `fit`."
254+ "`elementwise_loss` must accept (prediction, target, weight ) when `weights` is passed to `fit`."
255255 )
256256 else :
257257 ok = bool (jl .applicable (custom_loss , 1.0 , 1.0 ))
@@ -274,16 +274,17 @@ def _validate_custom_objective(
274274 raise ValueError (f"`{ knob } ` must evaluate to a callable Julia function." )
275275
276276 methods = jl .collect (jl .methods (custom_objective ))
277- accepts_three_args = any (
278- (not bool (m .isva ) and int (m .nargs ) == 4 ) or (bool (m .isva ) and int (m .nargs ) <= 4 )
279- for m in methods
280- )
281- appears_elementwise = any (
282- (not bool (m .isva ) and int (m .nargs ) == 3 ) or (bool (m .isva ) and int (m .nargs ) <= 3 )
283- for m in methods
284- )
285277
286- if not accepts_three_args and appears_elementwise :
278+ def _accepts_npos (m , npos : int ) -> bool :
279+ required_npos = int (m .nargs ) - 1
280+ if bool (m .isva ):
281+ return required_npos <= npos
282+ return required_npos == npos
283+
284+ accepts_three_args = any (_accepts_npos (m , 3 ) for m in methods )
285+ accepts_two_args = any (_accepts_npos (m , 2 ) for m in methods )
286+
287+ if not accepts_three_args and accepts_two_args :
287288 msg = (
288289 f"`{ knob } ` must have signature like { signature } . "
289290 f"If you intended an elementwise loss, use `{ elementwise_alternative } `."
0 commit comments