Skip to content

Commit 9ba1dea

Browse files
fix: clarify loss validator arity + error text
1 parent a132362 commit 9ba1dea

2 files changed

Lines changed: 31 additions & 10 deletions

File tree

pysr/sr.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def _validate_elementwise_loss(custom_loss, *, has_weights: bool) -> None:
251251
ok = bool(jl.applicable(custom_loss, 1.0, 1.0, 1.0))
252252
if not ok:
253253
raise ValueError(
254-
"`elementwise_loss` must accept (prediction, target, weights) when `weights` is passed to `fit`."
254+
"`elementwise_loss` must accept (prediction, target, weight) when `weights` is passed to `fit`."
255255
)
256256
else:
257257
ok = bool(jl.applicable(custom_loss, 1.0, 1.0))
@@ -274,16 +274,17 @@ def _validate_custom_objective(
274274
raise ValueError(f"`{knob}` must evaluate to a callable Julia function.")
275275

276276
methods = jl.collect(jl.methods(custom_objective))
277-
accepts_three_args = any(
278-
(not bool(m.isva) and int(m.nargs) == 4) or (bool(m.isva) and int(m.nargs) <= 4)
279-
for m in methods
280-
)
281-
appears_elementwise = any(
282-
(not bool(m.isva) and int(m.nargs) == 3) or (bool(m.isva) and int(m.nargs) <= 3)
283-
for m in methods
284-
)
285277

286-
if not accepts_three_args and appears_elementwise:
278+
def _accepts_npos(m, npos: int) -> bool:
279+
required_npos = int(m.nargs) - 1
280+
if bool(m.isva):
281+
return required_npos <= npos
282+
return required_npos == npos
283+
284+
accepts_three_args = any(_accepts_npos(m, 3) for m in methods)
285+
accepts_two_args = any(_accepts_npos(m, 2) for m in methods)
286+
287+
if not accepts_three_args and accepts_two_args:
287288
msg = (
288289
f"`{knob}` must have signature like {signature}. "
289290
f"If you intended an elementwise loss, use `{elementwise_alternative}`."

pysr/test/test_main.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,26 @@ def test_loss_function_valid_full_objective_runs(self):
242242
y = np.array([0.0, 1.0])
243243
model.fit(X, y)
244244

245+
def test_loss_function_varargs_objective_runs(self):
246+
model = PySRRegressor(
247+
niterations=1,
248+
populations=1,
249+
procs=0,
250+
progress=False,
251+
verbosity=0,
252+
temp_equation_file=True,
253+
binary_operators=["+"],
254+
loss_function="""
255+
begin
256+
varloss(tree, dataset, options...) = zero(eltype(dataset.y))
257+
varloss
258+
end
259+
""",
260+
)
261+
X = np.array([[0.0], [1.0]])
262+
y = np.array([0.0, 1.0])
263+
model.fit(X, y)
264+
245265
def test_elementwise_loss_wrong_signature_errors_early(self):
246266
"""Validate `elementwise_loss` signature (prediction, target[, weights])."""
247267
model = PySRRegressor(

0 commit comments

Comments
 (0)