@@ -239,12 +239,11 @@ def _jl_is_nothing(value):
239239 return bool (jl .isnothing (value ))
240240
241241
242- def _validate_elementwise_loss (custom_loss ) -> None :
243- """Validate that a Julia `elementwise_loss` is callable with 2 or 3 args .
242+ def _validate_elementwise_loss (custom_loss , * , has_weights : bool ) -> None :
243+ """Validate that a Julia `elementwise_loss` is callable.
244244
245- Expected signatures:
246- - (prediction, target)
247- - (prediction, target, weights)
245+ We require exactly 2 args unless the user passed `weights=` to fit,
246+ in which case we require 3 args.
248247 """
249248
250249 # This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
@@ -255,18 +254,19 @@ def _validate_elementwise_loss(custom_loss) -> None:
255254 if not jl_is_function (custom_loss ):
256255 return
257256
258- methods = jl .collect (jl .methods (custom_loss ))
259- ok = any (
260- (not bool (m .isva ) and int (m .nargs ) in {3 , 4 })
261- or (bool (m .isva ) and int (m .nargs ) <= 4 )
262- for m in methods
263- )
264- if not ok :
265- raise ValueError (
266- "`elementwise_loss` must have signature (prediction, target) or "
267- "(prediction, target, weights). If you intended a full objective, use "
268- "`loss_function` or `loss_function_expression`."
269- )
257+ if has_weights :
258+ ok = bool (jl .applicable (custom_loss , 1.0 , 1.0 , 1.0 ))
259+ if not ok :
260+ raise ValueError (
261+ "`elementwise_loss` must accept (prediction, target, weights) when `weights` is passed to `fit`."
262+ )
263+ else :
264+ ok = bool (jl .applicable (custom_loss , 1.0 , 1.0 ))
265+ if not ok :
266+ raise ValueError (
267+ "`elementwise_loss` must accept (prediction, target). If you intended a full objective, use "
268+ "`loss_function` or `loss_function_expression`."
269+ )
270270
271271
272272def _validate_custom_objective (
@@ -2129,7 +2129,7 @@ def _run(
21292129 else "nothing"
21302130 )
21312131 if self .elementwise_loss is not None :
2132- _validate_elementwise_loss (custom_loss )
2132+ _validate_elementwise_loss (custom_loss , has_weights = weights is not None )
21332133
21342134 custom_full_objective = jl .seval (
21352135 str (self .loss_function ) if self .loss_function is not None else "nothing"
0 commit comments