Skip to content

Commit 2f42042

Browse files
hyunn9973copybara-github
authored and committed
Add MSLE and RMSLE.
FUTURE_COPYBARA_INTEGRATE_REVIEW=#134 from google:new f680143 PiperOrigin-RevId: 842027777
1 parent ba88810 commit 2f42042

6 files changed

Lines changed: 165 additions & 0 deletions

File tree

src/metrax/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,14 @@
3333
MAE = regression_metrics.MAE
3434
MRR = ranking_metrics.MRR
3535
MSE = regression_metrics.MSE
36+
MSLE = regression_metrics.MSLE
3637
NDCGAtK = ranking_metrics.NDCGAtK
3738
Perplexity = nlp_metrics.Perplexity
3839
Precision = classification_metrics.Precision
3940
PrecisionAtK = ranking_metrics.PrecisionAtK
4041
PSNR = image_metrics.PSNR
4142
RMSE = regression_metrics.RMSE
43+
RMSLE = regression_metrics.RMSLE
4244
RSQUARED = regression_metrics.RSQUARED
4345
Recall = classification_metrics.Recall
4446
RecallAtK = ranking_metrics.RecallAtK
@@ -63,12 +65,14 @@
6365
"MAE",
6466
"MRR",
6567
"MSE",
68+
"MSLE",
6669
"NDCGAtK",
6770
"Perplexity",
6871
"Precision",
6972
"PrecisionAtK",
7073
"PSNR",
7174
"RMSE",
75+
"RMSLE",
7276
"RSQUARED",
7377
"Recall",
7478
"RecallAtK",

src/metrax/metrax_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ class MetraxTest(parameterized.TestCase):
140140
metrax.MSE,
141141
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
142142
),
143+
(
144+
'msle',
145+
metrax.MSLE,
146+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
147+
),
148+
143149
(
144150
'ndcgAtK',
145151
metrax.NDCGAtK,
@@ -182,6 +188,11 @@ class MetraxTest(parameterized.TestCase):
182188
metrax.RMSE,
183189
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
184190
),
191+
(
192+
'rmsle',
193+
metrax.RMSLE,
194+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
195+
),
185196
(
186197
'rsquared',
187198
metrax.RSQUARED,

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@
2727
MAE = nnx_metrics.MAE
2828
MRR = nnx_metrics.MRR
2929
MSE = nnx_metrics.MSE
30+
MSLE = nnx_metrics.MSLE
3031
NDCGAtK = nnx_metrics.NDCGAtK
3132
Perplexity = nnx_metrics.Perplexity
3233
Precision = nnx_metrics.Precision
3334
PrecisionAtK = nnx_metrics.PrecisionAtK
3435
PSNR = nnx_metrics.PSNR
3536
RMSE = nnx_metrics.RMSE
37+
RMSLE = nnx_metrics.RMSLE
3638
RSQUARED = nnx_metrics.RSQUARED
3739
Recall = nnx_metrics.Recall
3840
RecallAtK = nnx_metrics.RecallAtK

src/metrax/nnx/nnx_metrics.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,13 @@ def __init__(self):
107107
super().__init__(metrax.MSE)
108108

109109

110+
class MSLE(NnxWrapper):
111+
"""An NNX class for the Metrax metric MSLE."""
112+
113+
def __init__(self):
114+
super().__init__(metrax.MSLE)
115+
116+
110117
class NDCGAtK(NnxWrapper):
111118
"""An NNX class for the Metrax metric NDCGAtK."""
112119

@@ -163,6 +170,13 @@ def __init__(self):
163170
super().__init__(metrax.RMSE)
164171

165172

173+
class RMSLE(NnxWrapper):
174+
"""An NNX class for the Metrax metric RMSLE."""
175+
176+
def __init__(self):
177+
super().__init__(metrax.RMSLE)
178+
179+
166180
class RougeL(NnxWrapper):
167181
"""An NNX class for the Metrax metric RougeL."""
168182

src/metrax/regression_metrics.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,68 @@ def compute(self) -> jax.Array:
160160
return jnp.sqrt(super().compute())
161161

162162

163+
@flax.struct.dataclass
164+
class MSLE(base.Average):
165+
r"""Computes the mean squared logarithmic error for regression problems given `predictions` and `labels`.
166+
167+
The mean squared logarithmic error is defined as:
168+
169+
.. math::
170+
MSLE = \frac{1}{N} \sum_{i=1}^{N} (ln(y_i + 1) - ln(\hat{y}_i + 1))^2
171+
172+
where:
173+
- :math:`y_i` are true values
174+
- :math:`\hat{y}_i` are predictions
175+
- :math:`N` is the number of samples
176+
"""
177+
178+
@classmethod
179+
def from_model_output(
180+
cls,
181+
predictions: jax.Array,
182+
labels: jax.Array,
183+
) -> 'MSLE':
184+
"""Updates the metric.
185+
186+
Args:
187+
predictions: A floating point 1D vector representing the prediction
188+
generated from the model. The shape should be (batch_size,).
189+
labels: True value. The shape should be (batch_size,).
190+
191+
Returns:
192+
Updated MSLE metric. The shape should be a single scalar.
193+
"""
194+
log_predictions = jnp.log(predictions + 1)
195+
log_labels = jnp.log(labels + 1)
196+
squared_error = jnp.square(log_predictions - log_labels)
197+
count = jnp.ones_like(labels, dtype=jnp.int32)
198+
return cls(
199+
total=squared_error.sum(),
200+
count=count.sum(),
201+
)
202+
203+
204+
@flax.struct.dataclass
205+
class RMSLE(MSLE):
206+
r"""Computes the root mean squared logarithmic error for regression problems given `predictions` and `labels`.
207+
208+
The root mean squared logarithmic error is defined as:
209+
210+
.. math::
211+
RMSLE = \sqrt{\frac{1}{N} \sum_{i=1}^{N}
212+
(ln(y_i + 1) - ln(\hat{y}_i + 1))^2
213+
}
214+
215+
where:
216+
- :math:`y_i` are true values
217+
- :math:`\hat{y}_i` are predictions
218+
- :math:`N` is the number of samples
219+
"""
220+
221+
def compute(self) -> jax.Array:
222+
return jnp.sqrt(super().compute())
223+
224+
163225
@flax.struct.dataclass
164226
class RSQUARED(clu_metrics.Metric):
165227
r"""Computes the r-squared score of a scalar or a batch of tensors.

src/metrax/regression_metrics_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,78 @@ def test_rmse(self, y_true, y_pred, sample_weights):
230230
atol=atol,
231231
)
232232

233+
@parameterized.named_parameters(
234+
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16),
235+
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
236+
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16),
237+
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1),
238+
('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16),
239+
('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
240+
('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16),
241+
)
242+
def test_msle(self, y_true, y_pred):
243+
"""Test that `MSLE` Metric computes correct values."""
244+
y_true = y_true.astype(y_pred.dtype)
245+
y_pred = y_pred.astype(y_true.dtype)
246+
247+
metric = None
248+
for labels, logits in zip(y_true, y_pred):
249+
update = metrax.MSLE.from_model_output(
250+
predictions=logits,
251+
labels=labels,
252+
)
253+
metric = update if metric is None else metric.merge(update)
254+
255+
expected = sklearn_metrics.mean_squared_log_error(
256+
y_true.astype('float32').flatten(),
257+
y_pred.astype('float32').flatten(),
258+
)
259+
# Use lower tolerance for lower precision dtypes.
260+
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
261+
atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
262+
np.testing.assert_allclose(
263+
metric.compute(),
264+
expected,
265+
rtol=rtol,
266+
atol=atol,
267+
)
268+
269+
@parameterized.named_parameters(
270+
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16),
271+
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
272+
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16),
273+
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1),
274+
('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16),
275+
('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
276+
('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16),
277+
)
278+
def test_rmsle(self, y_true, y_pred):
279+
"""Test that `RMSLE` Metric computes correct values."""
280+
y_true = y_true.astype(y_pred.dtype)
281+
y_pred = y_pred.astype(y_true.dtype)
282+
283+
metric = None
284+
for labels, logits in zip(y_true, y_pred):
285+
update = metrax.RMSLE.from_model_output(
286+
predictions=logits,
287+
labels=labels,
288+
)
289+
metric = update if metric is None else metric.merge(update)
290+
291+
expected = sklearn_metrics.root_mean_squared_log_error(
292+
y_true.astype('float32').flatten(),
293+
y_pred.astype('float32').flatten(),
294+
)
295+
# Use lower tolerance for lower precision dtypes.
296+
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
297+
atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
298+
np.testing.assert_allclose(
299+
metric.compute(),
300+
expected,
301+
rtol=rtol,
302+
atol=atol,
303+
)
304+
233305
@parameterized.named_parameters(
234306
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
235307
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),

0 commit comments

Comments
 (0)