Skip to content

Commit f056a72

Browse files
fix: validate elementwise_loss arity based on weights
1 parent 4d10646 commit f056a72

1 file changed

Lines changed: 18 additions & 18 deletions

File tree

pysr/sr.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,11 @@ def _jl_is_nothing(value):
239239
return bool(jl.isnothing(value))
240240

241241

242-
def _validate_elementwise_loss(custom_loss) -> None:
243-
"""Validate that a Julia `elementwise_loss` is callable with 2 or 3 args.
242+
def _validate_elementwise_loss(custom_loss, *, has_weights: bool) -> None:
243+
"""Validate that a Julia `elementwise_loss` is callable.
244244
245-
Expected signatures:
246-
- (prediction, target)
247-
- (prediction, target, weights)
245+
We require exactly 2 args unless the user passed `weights=` to fit,
246+
in which case we require 3 args.
248247
"""
249248

250249
# This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
@@ -255,18 +254,19 @@ def _validate_elementwise_loss(custom_loss) -> None:
255254
if not jl_is_function(custom_loss):
256255
return
257256

258-
methods = jl.collect(jl.methods(custom_loss))
259-
ok = any(
260-
(not bool(m.isva) and int(m.nargs) in {3, 4})
261-
or (bool(m.isva) and int(m.nargs) <= 4)
262-
for m in methods
263-
)
264-
if not ok:
265-
raise ValueError(
266-
"`elementwise_loss` must have signature (prediction, target) or "
267-
"(prediction, target, weights). If you intended a full objective, use "
268-
"`loss_function` or `loss_function_expression`."
269-
)
257+
if has_weights:
258+
ok = bool(jl.applicable(custom_loss, 1.0, 1.0, 1.0))
259+
if not ok:
260+
raise ValueError(
261+
"`elementwise_loss` must accept (prediction, target, weights) when `weights` is passed to `fit`."
262+
)
263+
else:
264+
ok = bool(jl.applicable(custom_loss, 1.0, 1.0))
265+
if not ok:
266+
raise ValueError(
267+
"`elementwise_loss` must accept (prediction, target). If you intended a full objective, use "
268+
"`loss_function` or `loss_function_expression`."
269+
)
270270

271271

272272
def _validate_custom_objective(
@@ -2129,7 +2129,7 @@ def _run(
21292129
else "nothing"
21302130
)
21312131
if self.elementwise_loss is not None:
2132-
_validate_elementwise_loss(custom_loss)
2132+
_validate_elementwise_loss(custom_loss, has_weights=weights is not None)
21332133

21342134
custom_full_objective = jl.seval(
21352135
str(self.loss_function) if self.loss_function is not None else "nothing"

0 commit comments

Comments
 (0)