Skip to content

Commit 940a73b

Browse files
test: cover loss validation branches
1 parent f056a72 commit 940a73b

1 file changed

Lines changed: 52 additions & 1 deletion

File tree

pysr/test/test_main.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
_check_assertions,
4242
_process_constraints,
4343
_suggest_keywords,
44+
_validate_custom_objective,
45+
_validate_elementwise_loss,
4446
idx_model_selection,
4547
)
4648

@@ -250,14 +252,63 @@ def test_elementwise_loss_wrong_signature_errors_early(self):
250252
verbosity=0,
251253
temp_equation_file=True,
252254
binary_operators=["+"],
253-
elementwise_loss="loss(a) = a",
255+
elementwise_loss="myloss_bad_arity(a) = a",
254256
)
255257
X = np.array([[0.0], [1.0]])
256258
y = np.array([0.0, 1.0])
257259
with self.assertRaises(ValueError) as cm:
258260
model.fit(X, y)
259261
self.assertIn("elementwise_loss", str(cm.exception))
260262

263+
def test_elementwise_loss_with_weights_requires_three_args(self):
264+
model = PySRRegressor(
265+
niterations=1,
266+
populations=1,
267+
procs=0,
268+
progress=False,
269+
verbosity=0,
270+
temp_equation_file=True,
271+
binary_operators=["+"],
272+
elementwise_loss="myloss2(prediction, target) = (prediction - target)^2",
273+
)
274+
X = np.array([[0.0], [1.0]])
275+
y = np.array([0.0, 1.0])
276+
weights = np.array([1.0, 1.0])
277+
with self.assertRaises(ValueError) as cm:
278+
model.fit(X, y, weights=weights)
279+
self.assertIn("elementwise_loss", str(cm.exception))
280+
self.assertIn("weights", str(cm.exception))
281+
282+
def test_elementwise_loss_with_weights_accepts_three_args(self):
283+
model = PySRRegressor(
284+
niterations=1,
285+
populations=1,
286+
procs=0,
287+
progress=False,
288+
verbosity=0,
289+
temp_equation_file=True,
290+
binary_operators=["+"],
291+
elementwise_loss=(
292+
"myloss3(prediction, target, weights) = weights * (prediction - target)^2"
293+
),
294+
)
295+
X = np.array([[0.0], [1.0]])
296+
y = np.array([0.0, 1.0])
297+
weights = np.array([1.0, 1.0])
298+
model.fit(X, y, weights=weights)
299+
300+
def test_validation_helpers_allow_nothing(self):
301+
_validate_elementwise_loss(jl.seval("nothing"), has_weights=False)
302+
_validate_custom_objective(
303+
jl.seval("nothing"),
304+
knob="loss_function",
305+
signature="(tree, dataset, options)",
306+
elementwise_alternative="elementwise_loss",
307+
)
308+
309+
def test_validation_helpers_skip_nonfunction(self):
310+
_validate_elementwise_loss(jl.seval("1.0"), has_weights=False)
311+
261312
def test_loss_function_expression_elementwise_signature_errors_early(self):
262313
"""Validate `loss_function_expression` signature (expression, dataset, options)."""
263314
model = PySRRegressor(

0 commit comments

Comments
 (0)