Skip to content

Commit ad608d5

Browse files
Initial commit
Signed-off-by: 周唤海 <albus.zhouhh@gmail.com>
0 parents  commit ad608d5

33 files changed

Lines changed: 8267 additions & 0 deletions

.github/workflows/ci.yml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v4
14+
15+
- name: Install uv
16+
uses: astral-sh/setup-uv@v4
17+
18+
- name: Set up Python
19+
run: uv python install 3.12
20+
21+
- name: Install dependencies
22+
run: uv sync
23+
24+
- name: Run tests
25+
run: uv run pytest -v

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
.DS_Store
2+
3+
data/*.npz
4+
5+
*egg-info/
6+
*.pyc

AGENTS.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Agent Guidelines
2+
3+
Follow [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), with JAX patterns in this document taking precedence.
4+
5+
**Virtual environment**: Use `uv` for environment management. Run commands with `uv run`.
6+
7+
## Design Decisions
8+
9+
- **Match the paper.** Implementation should follow equations in `notes/Whitelam - 2026 - Generative Thermodynamic Computing.pdf`.
10+
- **DRY (Don't Repeat Yourself).** Consolidate duplicated implementations into a single source of truth.
11+
- **Let it crash.** Avoid defensive parameter checks; assume correct wiring and let errors surface.
12+
- **Uncertain correctness.** For uncertain behavior, refer to `notes/` or ask the user.
13+
- **Julia-style defaults.** Put defaults in function signature: `def foo(x, y=10):` not `def foo(x, y=None): y = y or 10`.
14+
- **No unnecessary intermediate variables.** Return directly: `return expr` not `result = expr; return result`.
15+
- **No unused imports/variables.** Remove any defined but unreferenced code.
16+
17+
## JAX Patterns
18+
19+
- **Use `jax.lax.scan`** over Python loops.
20+
- **Use `jax.vmap`** for batch operations (e.g., multiple denoising trajectories).
21+
- **No `jax.block_until_ready` in hot paths.** It breaks XLA fusion.
22+
- **Configurable dtype** via `config.dtype_name` (float32 or float64).
23+
24+
## Logging
25+
26+
- Use Python `logging` module, not `print()`.

README.md

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
A JAX replication of generative thermodynamic computing[^1][^2], which uses Langevin dynamics for visual generation. This repository provides MNIST digit synthesis as a minimal working example.
2+
3+
## How it works
4+
5+
Standard diffusion models use neural networks for denoising. Here, denoising is done by Langevin dynamics of a physical system with trained couplings — no neural network at inference time.
6+
7+
Training maximizes the probability that the system generates the reverse of noising trajectories, which is equivalent to minimizing heat dissipation. In hardware, this could be $\gt 10^{10}\times$ more efficient than digital computation.
8+
9+
**Noising** (image → noise):
10+
11+
![Noising trajectory](outputs/demo/fig1a_noising.png)
12+
13+
**Denoising & generation** (noise → image):
14+
15+
![Denoising and samples](outputs/demo/fig2.png)
16+
17+
## Installation
18+
19+
This project uses [uv](https://github.com/astral-sh/uv) for environment management.
20+
21+
```bash
22+
uv sync
23+
```
24+
25+
Verify the JAX backend:
26+
27+
```bash
28+
uv run python -c "import jax; print(jax.__version__); print(jax.default_backend())"
29+
```
30+
31+
## Data
32+
33+
Download MNIST:
34+
35+
```bash
36+
uv run download-mnist # saves to data/mnist.npz
37+
uv run download-mnist --out path/to/mnist.npz # custom path
38+
```
39+
40+
## Quick Start
41+
42+
Train a model and generate figures:
43+
44+
```bash
45+
uv run whitelam-2026 --mnist data/mnist.npz --out outputs/demo
46+
```
47+
48+
By default, the model trains on digits **(0, 1, 2)** with **512 hidden units**, matching the paper's setup.
49+
50+
## Configuration
51+
52+
Override any config field with `--set key=value` (repeatable). Values are parsed as JSON.
53+
54+
```bash
55+
# More training trajectories for better samples
56+
uv run whitelam-2026 --mnist data/mnist.npz --out outputs/demo \
57+
--set n_training_trajectories=1000
58+
59+
# Set random seed for reproducibility (default is 42)
60+
uv run whitelam-2026 --mnist data/mnist.npz --out outputs/demo --seed 123
61+
62+
# Higher DPI for publication-quality figures
63+
uv run whitelam-2026 --mnist data/mnist.npz --out outputs/demo --set fig_dpi=600
64+
```
65+
66+
### Scaling to All 10 Digits
67+
68+
The default model (512 hidden units) works well for 3 digit classes. To train on all digits (0-9), increase model capacity proportionally:
69+
70+
```bash
71+
uv run whitelam-2026 --mnist data/mnist.npz --out outputs/full_mnist \
72+
--set n_h=2048 \
73+
--set 'train_digits=[0,1,2,3,4,5,6,7,8,9]' \
74+
--set n_training_trajectories=1000
75+
```
76+
77+
**Scaling considerations:**
78+
- Training time scales with `n_h²` (hidden-hidden couplings)
79+
- Memory scales with `n_v × n_h` (visible-hidden couplings)
80+
81+
### Re-rendering Figures
82+
83+
To regenerate figures from saved parameters without retraining:
84+
85+
```bash
86+
uv run whitelam-2026 \
87+
--mnist data/mnist.npz \
88+
--params outputs/demo/params_learned.npz \
89+
--out outputs/demo_rerender
90+
```
91+
92+
## Outputs
93+
94+
Each run produces:
95+
96+
| File | Description |
97+
|------|-------------|
98+
| `fig1.png`, `fig2.png` | Composite figures |
99+
| `fig1a_noising.png` | Noising trajectory |
100+
| `fig1b_training_set.png` | Training digits |
101+
| `fig2a_denoising_trajectories.png` | Denoising trajectories |
102+
| `fig2b_samples.png` | Generated samples |
103+
| `fig2c_receptive_fields.png` | Learned hidden unit couplings |
104+
| `params_learned.npz` | Trained model parameters |
105+
| `config.json` | Configuration used |
106+
| `metrics.json` | Heat dissipation metrics |
107+
108+
## Development
109+
110+
Run tests:
111+
112+
```bash
113+
uv run pytest -v
114+
```
115+
116+
Use as a library:
117+
118+
```python
119+
from generative_langevin.whitelam_2026.config import Whitelam2026Config
120+
from generative_langevin.whitelam_2026.model import init_params
121+
from generative_langevin.whitelam_2026.train import train_many_noising_trajectories
122+
from generative_langevin.whitelam_2026.sample import run_denoising_trajectory
123+
```
124+
125+
## Reference
126+
127+
[^1]: S. Whitelam. **Generative Thermodynamic Computing**. *Physical Review Letters* 136(3):037101, 2026. https://doi.org/10.1103/kwyy-1xln
128+
129+
[^2]: https://github.com/swhitelam/generative_thermodynamic_computing
815 KB
Binary file not shown.

0 commit comments

Comments
 (0)