1717from typing import Callable
1818
1919import jax
20- from jax .nn import softplus
2120import jax .numpy as jnp
22- from jax .scipy .special import logsumexp
2321from jaxopt ._src .projection import projection_simplex , projection_hypercube
2422
23+ from optax import losses as optax_losses
24+
2525
2626# Regression
2727
@@ -39,10 +39,7 @@ def huber_loss(target: float, pred: float, delta: float = 1.0) -> float:
3939 References:
4040 https://en.wikipedia.org/wiki/Huber_loss
4141 """
42- abs_diff = jnp .abs (target - pred )
43- return jnp .where (abs_diff > delta ,
44- delta * (abs_diff - .5 * delta ),
45- 0.5 * abs_diff ** 2 )
42+ return optax_losses .huber_loss (pred , target , delta )
4643
4744# Binary classification.
4845
@@ -56,12 +53,8 @@ def binary_logistic_loss(label: int, logit: float) -> float:
5653 Returns:
5754 loss value
5855 """
59- # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1].
60- # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba),
61- # where xlogx(proba) = proba * log(proba).
62- # Use -log sigmoid(logit) = softplus(-logit)
63- # and 1 - sigmoid(logit) = sigmoid(-logit).
64- return softplus (jnp .where (label , - logit , logit ))
56+ return optax_losses .sigmoid_binary_cross_entropy (
57+ jnp .asarray (logit ), jnp .asarray (label ))
6558
6659
6760def binary_sparsemax_loss (label : int , logit : float ) -> float :
@@ -77,33 +70,7 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
7770 Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
7871 Vlad Niculae. JMLR 2020. (Sec. 4.4)
7972 """
80- return sparse_plus (jnp .where (label , - logit , logit ))
81-
82-
83- def sparse_plus (x : float ) -> float :
84- r"""Sparse plus function.
85-
86- Computes the function:
87-
88- .. math::
89-
90- \mathrm{sparse\_plus}(x) = \begin{cases}
91- 0, & x \leq -1\\
92- \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
93- x, & 1 \leq x
94- \end{cases}
95-
96- This is the twin function of the softplus activation ensuring a zero output
97- for inputs less than -1 and a linear output for inputs greater than 1,
98- while remaining smooth, convex, monotonic by an adequate definition between
99- -1 and 1.
100-
101- Args:
102- x: input (float)
103- Returns:
104- sparse_plus(x) as defined above
105- """
106- return jnp .where (x <= - 1.0 , 0.0 , jnp .where (x >= 1.0 , x , (x + 1.0 )** 2 / 4 ))
73+ return jax .nn .sparse_plus (jnp .where (label , - logit , logit ))
10774
10875
10976def sparse_sigmoid (x : float ) -> float :
@@ -144,8 +111,7 @@ def binary_hinge_loss(label: int, score: float) -> float:
144111 References:
145112 https://en.wikipedia.org/wiki/Hinge_loss
146113 """
147- signed_label = 2.0 * label - 1.0
148- return jnp .maximum (0 , 1 - score * signed_label )
114+ return optax_losses .hinge_loss (score , 2.0 * label - 1.0 )
149115
150116
151117def binary_perceptron_loss (label : int , score : float ) -> float :
@@ -160,8 +126,7 @@ def binary_perceptron_loss(label: int, score: float) -> float:
160126 References:
161127 https://en.wikipedia.org/wiki/Perceptron
162128 """
163- signed_label = 2.0 * label - 1.0
164- return jnp .maximum (0 , - score * signed_label )
129+ return optax_losses .perceptron_loss (score , 2.0 * label - 1.0 )
165130
166131# Multiclass classification.
167132
@@ -175,13 +140,8 @@ def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float:
175140 Returns:
176141 loss value
177142 """
178- logits = jnp .asarray (logits )
179- # Logsumexp is the Fenchel conjugate of the Shannon negentropy on the simplex.
180- # logsumexp = jnp.dot(proba, logits) - jnp.dot(proba, jnp.log(proba))
181- # To avoid roundoff error, subtract target inside logsumexp.
182- # logsumexp(logits) - logits[y] = logsumexp(logits - logits[y])
183- logits = (logits - logits [label ]).at [label ].set (0.0 )
184- return logsumexp (logits )
143+ return optax_losses .softmax_cross_entropy_with_integer_labels (
144+ jnp .asarray (logits ), jnp .asarray (label ))
185145
186146
187147def multiclass_sparsemax_loss (label : int , scores : jnp .ndarray ) -> float :
@@ -272,5 +232,6 @@ def make_fenchel_young_loss(max_fun: Callable[[jnp.ndarray], float]):
272232 """
273233
274234 def fy_loss (y_true , scores , * args , ** kwargs ):
275- return max_fun (scores , * args , ** kwargs ) - jnp .vdot (y_true , scores )
235+ return optax_losses .make_fenchel_young_loss (max_fun )(
236+ scores .ravel (), y_true .ravel (), * args , ** kwargs )
276237 return fy_loss
0 commit comments