|
41 | 41 | _check_assertions, |
42 | 42 | _process_constraints, |
43 | 43 | _suggest_keywords, |
| 44 | + _validate_custom_objective, |
| 45 | + _validate_elementwise_loss, |
44 | 46 | idx_model_selection, |
45 | 47 | ) |
46 | 48 |
|
@@ -250,14 +252,63 @@ def test_elementwise_loss_wrong_signature_errors_early(self): |
250 | 252 | verbosity=0, |
251 | 253 | temp_equation_file=True, |
252 | 254 | binary_operators=["+"], |
253 | | - elementwise_loss="loss(a) = a", |
| 255 | + elementwise_loss="myloss_bad_arity(a) = a", |
254 | 256 | ) |
255 | 257 | X = np.array([[0.0], [1.0]]) |
256 | 258 | y = np.array([0.0, 1.0]) |
257 | 259 | with self.assertRaises(ValueError) as cm: |
258 | 260 | model.fit(X, y) |
259 | 261 | self.assertIn("elementwise_loss", str(cm.exception)) |
260 | 262 |
|
| 263 | + def test_elementwise_loss_with_weights_requires_three_args(self): |
| 264 | + model = PySRRegressor( |
| 265 | + niterations=1, |
| 266 | + populations=1, |
| 267 | + procs=0, |
| 268 | + progress=False, |
| 269 | + verbosity=0, |
| 270 | + temp_equation_file=True, |
| 271 | + binary_operators=["+"], |
| 272 | + elementwise_loss="myloss2(prediction, target) = (prediction - target)^2", |
| 273 | + ) |
| 274 | + X = np.array([[0.0], [1.0]]) |
| 275 | + y = np.array([0.0, 1.0]) |
| 276 | + weights = np.array([1.0, 1.0]) |
| 277 | + with self.assertRaises(ValueError) as cm: |
| 278 | + model.fit(X, y, weights=weights) |
| 279 | + self.assertIn("elementwise_loss", str(cm.exception)) |
| 280 | + self.assertIn("weights", str(cm.exception)) |
| 281 | + |
| 282 | + def test_elementwise_loss_with_weights_accepts_three_args(self): |
| 283 | + model = PySRRegressor( |
| 284 | + niterations=1, |
| 285 | + populations=1, |
| 286 | + procs=0, |
| 287 | + progress=False, |
| 288 | + verbosity=0, |
| 289 | + temp_equation_file=True, |
| 290 | + binary_operators=["+"], |
| 291 | + elementwise_loss=( |
| 292 | + "myloss3(prediction, target, weights) = weights * (prediction - target)^2" |
| 293 | + ), |
| 294 | + ) |
| 295 | + X = np.array([[0.0], [1.0]]) |
| 296 | + y = np.array([0.0, 1.0]) |
| 297 | + weights = np.array([1.0, 1.0]) |
| 298 | + model.fit(X, y, weights=weights) |
| 299 | + |
| 300 | + def test_validation_helpers_allow_nothing(self): |
| 301 | + _validate_elementwise_loss(jl.seval("nothing"), has_weights=False) |
| 302 | + _validate_custom_objective( |
| 303 | + jl.seval("nothing"), |
| 304 | + knob="loss_function", |
| 305 | + signature="(tree, dataset, options)", |
| 306 | + elementwise_alternative="elementwise_loss", |
| 307 | + ) |
| 308 | + |
| 309 | + def test_validation_helpers_skip_nonfunction(self): |
| 310 | + _validate_elementwise_loss(jl.seval("1.0"), has_weights=False) |
| 311 | + |
261 | 312 | def test_loss_function_expression_elementwise_signature_errors_early(self): |
262 | 313 | """Validate `loss_function_expression` signature (expression, dataset, options).""" |
263 | 314 | model = PySRRegressor( |
|
0 commit comments