Not the neurons we want, but the neurons we need
Activation-free neural layers that learn non-linearity through geometric operations
📚 Documentation · 📄 Read the Paper · 📝 Read the Blog · 🐛 Report Bug
NMN replaces traditional Linear + ReLU with a single geometric operation that learns non-linearity without activation functions:
```python
# Traditional approach
y = relu(linear(x))  # dot product → activation

# NMN approach
y = yat(x)           # geometric operation with built-in non-linearity
```

The Yat-Product (ⵟ) balances similarity and distance to create inherently non-linear transformations, so no activation functions are needed.
| Feature | Description |
|---|---|
| 🔥 Activation-Free | Learn complex non-linear relationships without ReLU, sigmoid, or tanh |
| 🌐 Multi-Framework | PyTorch, TensorFlow, Keras, Flax (Linen & NNX) |
| 🧮 Geometric Foundation | Based on distance-similarity tradeoff, not just correlations |
| ✅ Cross-Framework Consistency | Verified numerical equivalence across all frameworks |
| 🧠 Complete Layer Suite | Dense, Conv1D/2D/3D, ConvTranspose, Attention, RNN cells |
| ⚡ Production Ready | Comprehensive tests, CI/CD, high code coverage |
The core operation that powers NMN:
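$$ ⵟ(\mathbf{w}, \mathbf{x}) = \frac{\langle\mathbf{w}, \mathbf{x}\rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} $$

The numerator is the squared dot product (similarity); the denominator is the squared Euclidean distance between the weight and the input, plus ε for numerical stability (proximity).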
🔍 Geometric Interpretation
Rewriting in terms of norms and angles:
$$ ⵟ(\mathbf{w}, \mathbf{x}) = \frac{\|\mathbf{w}\|^2 \|\mathbf{x}\|^2 \cos^2\theta}{\|\mathbf{w}\|^2 - 2\langle\mathbf{w}, \mathbf{x}\rangle + \|\mathbf{x}\|^2 + \epsilon} $$
Output is maximized when:
- ✅ Vectors are aligned (small θ → large cos²θ)
- ✅ Vectors are close (small Euclidean distance)
- ✅ Vectors have large magnitude (amplifies the signal)
This creates a fundamentally different learning dynamic:
| Traditional Neuron | Yat Neuron |
|---|---|
| Measures correlation only | Balances similarity AND proximity |
| Requires activation for non-linearity | Non-linearity is intrinsic |
| Can fire for distant but aligned vectors | Penalizes distance between w and x |
The same principle applied to local patches:
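In the same compact form as the dense case (a sketch; the packaged layers handle striding, padding, and per-channel details):

$$ ⵟ(\mathbf{W}, \mathbf{X}) = \frac{\langle\mathbf{W}, \mathbf{X}\rangle^2}{\|\mathbf{W} - \mathbf{X}\|^2 + \epsilon} $$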
Where W is the kernel and X is the input patch.
```bash
pip install nmn
```

```bash
# Framework-specific installations
pip install "nmn[torch]"   # PyTorch
pip install "nmn[keras]"   # Keras/TensorFlow
pip install "nmn[nnx]"     # Flax NNX (JAX)
pip install "nmn[all]"     # Everything
```
**PyTorch**

```python
import torch
from nmn.torch.nmn import YatNMN

# Replace nn.Linear + activation
layer = YatNMN(
    in_features=128,
    out_features=64,
    epsilon=1e-5,
)

x = torch.randn(32, 128)
y = layer(x)  # (32, 64), non-linear output
```
**Keras**

```python
import keras
from nmn.keras.nmn import YatNMN

# Drop-in replacement for Dense
layer = YatNMN(
    features=64,
    epsilon=1e-5,
)

x = keras.ops.zeros((32, 128))
y = layer(x)  # (32, 64)
```
**Flax NNX**

```python
import jax.numpy as jnp
from flax import nnx
from nmn.nnx.nmn import YatNMN

layer = YatNMN(
    in_features=128,
    out_features=64,
    rngs=nnx.Rngs(0),
)

x = jnp.zeros((32, 128))
y = layer(x)  # (32, 64)
```
**TensorFlow**

```python
import tensorflow as tf
from nmn.tf.nmn import YatNMN

layer = YatNMN(features=64)

x = tf.zeros((32, 128))
y = layer(x)  # (32, 64)
```
| Layer | PyTorch | TensorFlow | Keras | Flax NNX | Flax Linen |
|---|---|---|---|---|---|
| YatNMN (Dense) | ✅ | ✅ | ✅ | ✅ | ✅ |
| YatConv1D | ✅ | ✅ | ✅ | ✅ | ✅ |
| YatConv2D | ✅ | ✅ | ✅ | ✅ | ✅ |
| YatConv3D | ✅ | ✅ | ✅ | ✅ | ✅ |
| YatConvTranspose1D | ✅ | ✅ | ✅ | ✅ | ❌ |
| YatConvTranspose2D | ✅ | ✅ | ✅ | ✅ | ❌ |
| YatConvTranspose3D | ✅ | ✅ | ❌ | ✅ | ❌ |
| Layer | Status | Description |
|---|---|---|
| MultiHeadAttention | ✅ | Yat-based attention mechanism |
| YatSimpleCell | ✅ | Simple RNN cell |
| YatLSTMCell | ✅ | LSTM with Yat operations |
| YatGRUCell | ✅ | GRU with Yat operations |
| softermax | ✅ | Generalized softmax |
| softer_sigmoid | ✅ | Smooth sigmoid variant |
| soft_tanh | ✅ | Smooth tanh variant |
| DropConnect | ✅ | Weight-level dropout regularization |
All implementations are verified to produce numerically equivalent outputs given identical inputs and weights:
```
┌─────────────────────────────────────────────────────────────┐
│               Cross-Framework Consistency Test              │
├──────────────────────────┬──────────────┬───────────────────┤
│ Framework Pair           │ Max Error    │ Status            │
├──────────────────────────┼──────────────┼───────────────────┤
│ PyTorch ↔ TensorFlow     │ < 1e-6       │ ✅ PASS           │
│ PyTorch ↔ Keras          │ < 1e-6       │ ✅ PASS           │
│ PyTorch ↔ Flax NNX       │ < 1e-6       │ ✅ PASS           │
│ PyTorch ↔ Flax Linen     │ < 1e-6       │ ✅ PASS           │
│ TensorFlow ↔ Keras       │ < 1e-7       │ ✅ PASS           │
│ Flax NNX ↔ Flax Linen    │ < 1e-7       │ ✅ PASS           │
└──────────────────────────┴──────────────┴───────────────────┘
```
This demonstrates the robustness of the geometric YAT formulation across different numerical backends.
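Conceptually, each pairing boils down to running identically initialized layers on the same input and bounding the maximum elementwise error. A minimal sketch of such a check, with `layer_a` and `layer_b` standing in for two already-built framework layers (illustrative only; the actual tests live in `tests/integration/test_cross_framework_consistency.py`):

```python
import numpy as np

def check_consistency(layer_a, layer_b, x, atol=1e-6):
    """Compare two framework implementations on the same input.

    `layer_a` and `layer_b` are callables mapping a NumPy array to a
    NumPy-convertible output, assumed to share identical weights
    (hypothetical stand-ins, not part of the nmn API).
    """
    y_a = np.asarray(layer_a(x))
    y_b = np.asarray(layer_b(x))
    max_err = float(np.max(np.abs(y_a - y_b)))
    assert max_err < atol, f"max error {max_err:.2e} exceeds {atol:.0e}"
    return max_err
```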
See EXAMPLES.md for comprehensive usage guides including:
- Framework-specific quick starts (PyTorch, Keras, TensorFlow, Flax)
- Architecture examples (CNN, Transformer, RNN)
- Advanced features (DropConnect, custom squashers, attention)
Quick run:

```bash
python examples/torch/yat_cifar10.py     # PyTorch CIFAR-10
python examples/keras/language_imdb.py   # Keras sentiment
python examples/nnx/language/mingpt.py   # JAX GPT
```

Comprehensive test suite with cross-framework validation:
```bash
# Install test dependencies
pip install "nmn[test]"

# Run all tests
pytest tests/ -v

# Run specific framework
pytest tests/test_torch/ -v
pytest tests/test_keras/ -v
pytest tests/test_nnx/ -v

# Run cross-framework consistency tests
pytest tests/integration/test_cross_framework_consistency.py -v

# With coverage
pytest tests/ --cov=nmn --cov-report=html
```

Test suite layout:

```
tests/
├── test_torch/     # PyTorch layer tests + math validation
├── test_keras/     # Keras layer tests
├── test_tf/        # TensorFlow layer tests
├── test_nnx/       # Flax NNX tests (attention, RNN, etc.)
├── test_linen/     # Flax Linen tests
└── integration/
    ├── test_cross_framework_consistency.py  # Numerical equivalence
    └── test_compatibility.py                # API compatibility
```
Based on the research papers:
- Deep Learning 2.0: Artificial Neurons that Matter — Reject Correlation, Embrace Orthogonality
- Deep Learning 2.1: Mind and Cosmos — Towards Cosmos-Inspired Interpretable Neural Networks
Traditional neurons compute:
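$$ y = \sigma(\langle\mathbf{w}, \mathbf{x}\rangle + b) $$

a dot product plus bias, passed through an external activation σ such as ReLU.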
This has limitations:
- Correlation-based: Only measures alignment, ignores proximity
- Requires activation: Non-linearity is external
- Spurious activations: Can fire strongly for distant but aligned vectors
The Yat-Product addresses these by combining:
- Squared dot product (similarity) in the numerator
- Squared distance (proximity) in the denominator
- Epsilon for numerical stability
The result is a neuron that responds geometrically — activated when inputs are both similar AND close to weights.
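This response pattern is easy to verify numerically. A minimal NumPy sketch of a single Yat neuron (illustrative only; the packaged layers add bias, learnable scaling, and framework-specific initialization):

```python
import numpy as np

def yat_neuron(w, x, epsilon=1e-5):
    """Yat-Product of weight w and input x: squared dot product
    divided by squared Euclidean distance (plus epsilon)."""
    similarity = np.dot(w, x) ** 2        # squared dot product (alignment)
    proximity = np.sum((w - x) ** 2)      # squared Euclidean distance
    return similarity / (proximity + epsilon)

w = np.array([1.0, 1.0])
near = np.array([1.1, 0.9])   # aligned and close to w
far = np.array([10.0, 10.0])  # aligned but far from w

print(yat_neuron(w, near))  # ~200: similar AND close → strong response
print(yat_neuron(w, far))   # ~2.5: aligned but distant → heavily damped
```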
We welcome contributions! See CONTRIBUTING.md for guidelines.
```bash
# Development setup
git clone https://github.com/mlnomadpy/nmn.git
cd nmn
pip install -e ".[dev,test]"

# Run tests
pytest tests/ -v

# Format code
black src/ tests/
isort src/ tests/
```

Areas for contribution:
- 🐛 Bug fixes (open issues)
- ✨ New layer types (normalization, graph, etc.)
- 📚 Documentation and tutorials
- ⚡ Performance optimizations
- 🎨 Example applications
| Parameter | Type | Description |
|---|---|---|
| `in_features` | int | Input dimension (Dense) or channels (Conv) |
| `out_features` | int | Output dimension or filters |
| `kernel_size` | int \| tuple | Convolution kernel size |
| `epsilon` | float | Numerical stability constant (default: 1e-5) |
| `use_bias` | bool | Include bias term (default: True) |
| `use_alpha` | bool | Learnable output scaling (default: True) |
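For example, a dense Yat layer configured with every parameter from the table (shown with the PyTorch constructor, assuming it accepts these keywords exactly as documented above):

```python
import torch
from nmn.torch.nmn import YatNMN

# Keyword names follow the parameter table above (assumed to apply to the
# PyTorch constructor; defaults shown explicitly for clarity)
layer = YatNMN(
    in_features=128,
    out_features=64,
    epsilon=1e-5,    # numerical stability
    use_bias=True,   # include bias term
    use_alpha=True,  # learnable output scaling
)

x = torch.randn(8, 128)
y = layer(x)  # (8, 64)
```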
```python
# PyTorch
from nmn.torch.nmn import YatNMN
from nmn.torch.layers import YatConv2d, YatConvTranspose2d

# Keras / TensorFlow
from nmn.keras.nmn import YatNMN
from nmn.keras.conv import YatConv2D

# Flax NNX (most complete)
from nmn.nnx.nmn import YatNMN
from nmn.nnx.yatconv import YatConv
from nmn.nnx.yatattention import MultiHeadAttention
from nmn.nnx.rnn import YatLSTMCell
```

📋 Full import reference → EXAMPLES.md
If you use NMN in your research, please cite:
```bibtex
@software{nmn2024,
  author = {Bouhsine, Taha},
  title  = {NMN: Neural Matter Networks},
  year   = {2024},
  url    = {https://github.com/mlnomadpy/nmn}
}

@article{bouhsine2024dl2,
  author = {Taha Bouhsine},
  title  = {Deep Learning 2.0: Artificial Neurons that Matter},
  year   = {2024}
}

@article{bouhsine2025dl21,
  author = {Taha Bouhsine},
  title  = {Deep Learning 2.1: Mind and Cosmos},
  year   = {2025}
}

@article{bouhsine2025nomoredelulu,
  author = {Taha Bouhsine},
  title  = {No More DeLuLu: A Kernel-Based Activation-Free Neural Networks},
  year   = {2025}
}
```

- 🐛 Issues: GitHub Issues
- 💬 Discussions: GitHub Discussions
- 📧 Contact: taha@azetta.ai
AGPL-3.0 — Free for personal, academic, and commercial use with attribution.
If you modify and deploy on a network, you must share the source code.
For alternative licensing, contact us.
Built with ❤️ by azetta.ai
