@@ -230,6 +230,78 @@ def test_rmse(self, y_true, y_pred, sample_weights):
230230 atol = atol ,
231231 )
232232
233+ @parameterized .named_parameters (
234+ ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 ),
235+ ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 ),
236+ ('basic_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 ),
237+ ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 ),
238+ ('weighted_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 ),
239+ ('weighted_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 ),
240+ ('weighted_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 ),
241+ )
242+ def test_msle (self , y_true , y_pred ):
243+ """Test that `MSLE` Metric computes correct values."""
244+ y_true = y_true .astype (y_pred .dtype )
245+ y_pred = y_pred .astype (y_true .dtype )
246+
247+ metric = None
248+ for labels , logits in zip (y_true , y_pred ):
249+ update = metrax .MSLE .from_model_output (
250+ predictions = logits ,
251+ labels = labels ,
252+ )
253+ metric = update if metric is None else metric .merge (update )
254+
255+ expected = sklearn_metrics .mean_squared_log_error (
256+ y_true .astype ('float32' ).flatten (),
257+ y_pred .astype ('float32' ).flatten (),
258+ )
259+ # Use lower tolerance for lower precision dtypes.
260+ rtol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
261+ atol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
262+ np .testing .assert_allclose (
263+ metric .compute (),
264+ expected ,
265+ rtol = rtol ,
266+ atol = atol ,
267+ )
268+
269+ @parameterized .named_parameters (
270+ ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 ),
271+ ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 ),
272+ ('basic_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 ),
273+ ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 ),
274+ ('weighted_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 ),
275+ ('weighted_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 ),
276+ ('weighted_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 ),
277+ )
278+ def test_rmsle (self , y_true , y_pred ):
279+ """Test that `RMSLE` Metric computes correct values."""
280+ y_true = y_true .astype (y_pred .dtype )
281+ y_pred = y_pred .astype (y_true .dtype )
282+
283+ metric = None
284+ for labels , logits in zip (y_true , y_pred ):
285+ update = metrax .RMSLE .from_model_output (
286+ predictions = logits ,
287+ labels = labels ,
288+ )
289+ metric = update if metric is None else metric .merge (update )
290+
291+ expected = sklearn_metrics .root_mean_squared_log_error (
292+ y_true .astype ('float32' ).flatten (),
293+ y_pred .astype ('float32' ).flatten (),
294+ )
295+ # Use lower tolerance for lower precision dtypes.
296+ rtol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
297+ atol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
298+ np .testing .assert_allclose (
299+ metric .compute (),
300+ expected ,
301+ rtol = rtol ,
302+ atol = atol ,
303+ )
304+
233305 @parameterized .named_parameters (
234306 ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 , None ),
235307 ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 , None ),
0 commit comments