ACKNOWLEDGMENTS.md (1 addition, 1 deletion)
@@ -19,7 +19,7 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `XieLU` and `ReLU²` activation functions.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
docs/src/python/nn/layers.rst (1 addition, 0 deletions)
@@ -58,6 +58,7 @@ Layers
Sequential
Sigmoid
SiLU
XieLU
SinusoidalPositionalEncoding
Softmin
Softshrink
python/mlx/nn/layers/__init__.py (1 addition, 0 deletions)
@@ -26,6 +26,7 @@
    Softsign,
    Step,
    Tanh,
    XieLU,
    celu,
    elu,
    gelu,
python/mlx/nn/layers/activations.py (60 additions, 0 deletions)
@@ -98,6 +98,34 @@ def softplus(x):
    return mx.logaddexp(x, 0)


@partial(mx.compile, shapeless=True)
def xielu(x, alpha_p, alpha_n, beta, eps):
    r"""Applies the XieLU activation function.

    This function uses parameterized positive and negative scaling with
    exponential smoothing.

    .. math::
        \text{XieLU}(x) = \begin{cases}
            \alpha_p x^2 + \beta x & \text{if } x > 0 \\
            \alpha_n \left(\exp(\min(x, \epsilon)) - 1 - x\right) + \beta x & \text{if } x \leq 0
        \end{cases}

    Args:
        x: The input array.
        alpha_p: Positive scaling parameter (softplus applied).
        alpha_n: Negative scaling parameter (shifted by beta and softplus applied).
        beta: Linear scaling factor.
        eps: Clamping value for stability in the negative region.
    """
    alpha_p = mx.logaddexp(alpha_p, 0)
    alpha_n = beta + mx.logaddexp(alpha_n, 0)
    return mx.where(
        x > 0,
        alpha_p * mx.square(x) + beta * x,
        (mx.expm1(mx.minimum(x, eps)) - x) * alpha_n + beta * x,
    )


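For reference, a minimal usage sketch of the functional form added above (assumptions: the import path `mlx.nn.layers.activations` from this diff, and arbitrary example parameter values; `xielu` applies softplus to the raw `alpha_p`/`alpha_n` it receives):

```python
import mlx.core as mx
from mlx.nn.layers.activations import xielu  # module path per this diff

x = mx.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# Raw (pre-softplus) parameters; with 0.0 the effective positive scale is
# softplus(0.0) = log(2) ≈ 0.693 and the negative scale is beta + log(2).
out = xielu(x, mx.array(0.0), mx.array(0.0), mx.array(0.5), mx.array(-1e-6))
print(out)  # x > 0 branch: ≈ 0.693 * x**2 + 0.5 * x
```
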
@partial(mx.compile, shapeless=True)
def softsign(x):
    r"""Applies the Softsign function.
@@ -541,6 +569,38 @@ def __call__(self, x: mx.array):
        return prelu(x, self.weight)


class XieLU(Module):
    r"""Applies the XieLU activation function.

    See :func:`xielu` for the functional equivalent.

    Args:
        alpha_p_init (float): Initial value for the positive scaling parameter. Default: 0.8
        alpha_n_init (float): Initial value for the negative scaling parameter. Default: 0.8
        beta (float): Linear scaling factor. Default: 0.5
        eps (float): Clamping value for stability in the negative region. Default: -1e-6
    """

    def __init__(
        self,
        alpha_p_init=0.8,
        alpha_n_init=0.8,
        beta=0.5,
        eps=-1e-6,
    ):
        super().__init__()
        # Store the scaling parameters in inverse-softplus space so that the
        # softplus applied inside `xielu` recovers the requested initial values.
        alpha_p_tensor = mx.array(alpha_p_init)
        alpha_n_tensor = mx.array(alpha_n_init - beta)
        self.alpha_p = mx.log(mx.exp(alpha_p_tensor) - 1)
        self.alpha_n = mx.log(mx.exp(alpha_n_tensor) - 1)

        self.beta = mx.array(beta)
        self.eps = mx.array(eps)

    def __call__(self, x: mx.array) -> mx.array:
        return xielu(x, self.alpha_p, self.alpha_n, self.beta, self.eps)


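And a minimal sketch of the module form (assuming `XieLU` is re-exported through `mlx.nn` via the `__init__.py` change above):

```python
import mlx.core as mx
import mlx.nn as nn

layer = nn.XieLU()  # defaults: alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5, eps=-1e-6
y = layer(mx.array([-1.0, 0.0, 1.0]))
print(y)

# Because the parameters are stored in inverse-softplus space,
# softplus(layer.alpha_p) ≈ 0.8 and beta + softplus(layer.alpha_n) ≈ 0.8,
# matching the requested initial values.
```
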
class GELU(Module):
    r"""Applies the Gaussian Error Linear Units.
