# pyproject.toml — GenJAX project configuration (fork of femtomc/genjax).
# Core project metadata (PEP 621).
[project]
name = "genjax"
version = "1.0.13"
requires-python = ">= 3.12"
authors = [{ name = "McCoy Becker", email = "mccoyb@mit.edu" }]
# Runtime dependencies; each is capped at its current minor series.
dependencies = [
    "penzai>=0.2.5,<0.3",
    "beartype>=0.22.9,<0.23",
    "jaxtyping>=0.3.9,<0.4",
    "tensorflow-probability>=0.25.0,<0.26",
]
# PEP 517/518 build configuration: build with Hatchling.
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
# Pixi workspace: conda channels searched and the platforms solved for.
[tool.pixi.workspace]
channels = ["conda-forge", "pytorch", "nvidia"]
platforms = ["osx-arm64", "linux-64"]
# Conda dependencies shared by every environment.
[tool.pixi.dependencies]
jax = ">=0.7.2,<0.8"
matplotlib = ">=3.10.3,<4"
# Install this package itself via pip in editable mode.
[tool.pixi.pypi-dependencies]
genjax = { path = ".", editable = true }
# Workspace-level tasks, available from the default environment. Most of the
# "paper-*" tasks shell out to `pixi run -e <env>` so each case study runs in
# its own environment (defined in [tool.pixi.environments]).
[tool.pixi.tasks]
# Clean all generated figures from examples
clean-figs = "find examples -name '*.pdf' -o -name '*.png' -o -name '*.jpg' -o -name '*.svg' | grep '/figs/' | xargs rm -f"
# =====================================================================
# PAPER FIGURE GENERATION - Centralized to genjax/figs/
# =====================================================================
# All figures for the POPL 2026 paper are generated to genjax/figs/
# Works on CPU (osx-arm64) and GPU (linux-64 with CUDA)
# Setup: Create central figs directory
paper-setup = "mkdir -p figs"
# Individual case study tasks (run within their environments)
paper-faircoin-gen = "pixi run -e faircoin python -m examples.faircoin.main --combined --num-obs 50 --num-samples 2000 --repeats 10"
paper-curvefit-gen = "pixi run -e curvefit python -m examples.curvefit.main paper"
paper-curvefit-custom = "pixi run -e curvefit python -m examples.curvefit.main"
paper-gol-gen = "pixi run -e gol python -m examples.gol.main --mode showcase"
paper-localization-gen = "pixi run -e localization python -m examples.localization.main paper --include-smc-comparison --n-particles 200 --n-steps 8 --timing-repeats 3 --n-rays 8 --output-dir figs"
paper-perfbench = "pixi run -e perfbench python examples/perfbench/main.py pipeline"
paper-perfbench-cuda = "pixi run -e perfbench-cuda python examples/perfbench/main.py pipeline --mode cuda"
paper-perfbench-clean = "python examples/perfbench/main.py clean"
paper-cone-gen = "python -m examples.cone.main fig2"
paper-cone-table = "python -m examples.cone.main table4"
# AIR estimator comparison (GenJAX-only)
air-compare = "python -m examples.air.main compare --small-config --num-examples 256 --epochs 2"
air-train = "python -m examples.air.main train --estimator enum --small-config --num-examples 256 --epochs 2"
# Generate/copy multi-MNIST data using Pyro's helper (writes examples/air/data/multi_mnist_uint8.npz)
air-fetch-data = "pixi run -e perfbench-pyro python -m examples.air.main fetch-data --output examples/air/data/multi_mnist_uint8.npz --cache-root /tmp/air-data"
# GPU AIR runs (use /dev/shm temp + disable Triton GEMM autotuning for constrained /tmp systems)
air-train-gpu = "TMPDIR=/dev/shm XLA_FLAGS='--xla_gpu_enable_triton_gemm=false --xla_gpu_autotune_level=0' pixi run -e cuda python -m examples.air.main train --dataset multi-mnist --data-path examples/air/data/multi_mnist_uint8.npz --num-examples 2048 --epochs 10 --batch-size 32 --eval-batch-size 128 --learning-rate 1e-4 --estimator enum"
air-compare-gpu = "TMPDIR=/dev/shm XLA_FLAGS='--xla_gpu_enable_triton_gemm=false --xla_gpu_autotune_level=0' pixi run -e cuda python -m examples.air.main compare --dataset multi-mnist --data-path examples/air/data/multi_mnist_uint8.npz --num-examples 2048 --epochs 10 --batch-size 32 --eval-batch-size 128 --learning-rate 1e-4"
# GPU (CUDA) variants of the case-study figure tasks; each targets the
# corresponding *-cuda environment from [tool.pixi.environments].
paper-faircoin-gpu-gen = "pixi run -e faircoin-cuda python -m examples.faircoin.main --combined --num-obs 50 --num-samples 2000 --repeats 10"
paper-curvefit-gpu-gen = "pixi run -e curvefit-cuda python -m examples.curvefit.main paper"
paper-curvefit-gpu-custom = "pixi run -e curvefit-cuda python -m examples.curvefit.main"
paper-gol-gpu-gen = "pixi run -e gol-cuda python -m examples.gol.main --mode showcase"
paper-localization-gpu-gen = "pixi run -e localization-cuda python -m examples.localization.main paper --include-smc-comparison --n-particles 200 --n-steps 8 --timing-repeats 3 --n-rays 8 --output-dir figs"
# Main paper figures generation tasks
# Aggregate tasks: run paper-setup plus every per-case-study generation task.
paper-figures = { depends-on = ["paper-setup", "paper-faircoin-gen", "paper-curvefit-gen", "paper-gol-gen", "paper-localization-gen"] }
paper-figures-gpu = { depends-on = ["paper-setup", "paper-faircoin-gpu-gen", "paper-curvefit-gpu-gen", "paper-gol-gpu-gen", "paper-localization-gpu-gen"] }
# Vulture (dead-code detector) settings for the `vulture` task below.
[tool.vulture]
make_whitelist = true
min_confidence = 80
paths = ["src"]
sort_by_size = true
# "faircoin" feature: plotting deps plus NumPyro for the framework comparison.
[tool.pixi.feature.faircoin.dependencies]
matplotlib = "*"
seaborn = "*"
[tool.pixi.feature.faircoin.pypi-dependencies]
numpyro = "*"
[tool.pixi.feature.faircoin.tasks]
# NOTE(review): this defines a task literally named "cmd" whose command is
# the string "faircoin"; the same pattern repeats in the other feature task
# tables — confirm it is intentional and not leftover scaffolding.
cmd = "faircoin"
# Beta-Bernoulli framework comparison (GenJAX vs NumPyro vs handcoded JAX)
faircoin = "python -m examples.faircoin.main --combined" # Main command (combined timing + posterior)
faircoin-timing = "python -m examples.faircoin.main" # Timing comparison only
faircoin-combined = "python -m examples.faircoin.main --combined" # Combined timing + posterior figure (recommended)
# Paper-specific task: generates faircoin_combined_posterior_and_timing_obs50_samples2000.pdf to figs/
faircoin-paper = "python -m examples.faircoin.main --combined --num-obs 50 --num-samples 2000"
# "curvefit" feature: plotting deps plus NumPyro/funsor for comparisons.
[tool.pixi.feature.curvefit.dependencies]
matplotlib = "*"
numpy = "*"
pygments = "*"
seaborn = "*"
[tool.pixi.feature.curvefit.pypi-dependencies]
numpyro = "*"
funsor = "*"
[tool.pixi.feature.curvefit.tasks]
# NOTE(review): task named "cmd" running the shell command "curvefit" —
# same suspicious pattern as in the faircoin feature; confirm intent.
cmd = "curvefit"
# Paper-friendly defaults (only "paper" mode is supported)
curvefit = "python -m examples.curvefit.main paper" # Default POPL figure generation
curvefit-custom = "python -m examples.curvefit.main" # Pass flags after `--` to customise runs
# "format" feature tasks. The format-md/format-all tasks require npx, which
# comes from the nodejs dependency declared later in
# [tool.pixi.feature.format.dependencies].
[tool.pixi.feature.format.tasks]
# Code formatting and linting
format = "ruff format . && ruff check . --fix" # Format and lint Python code with ruff
format-md = "npx prettier --write '**/*.md'" # Format Markdown files with prettier
format-all = "ruff format . && ruff check . --fix && npx prettier --write '**/*.md'" # Format both Python and Markdown files
vulture = "vulture" # Find unused code
precommit-install = "pre-commit install" # Install pre-commit hooks
precommit-run = "pre-commit run --all-files" # Run pre-commit hooks
# "test" feature tasks; the pytest/coverage tooling itself comes from the
# matching "test" group in [dependency-groups].
[tool.pixi.feature.test.tasks]
# Testing and coverage
test = "pytest tests/ -v --cov=src/genjax --cov-report=xml --cov-report=html --cov-report=term" # Run tests with coverage
test-parallel = "pytest tests/ -v -n auto --cov=src/genjax --cov-report=xml --cov-report=html --cov-report=term" # Run tests in parallel with auto-detected cores
test-fast = "pytest tests/ -v -n 4 -m 'not slow' --cov=src/genjax" # Run fast tests on 4 cores
coverage = "pytest tests/ -v --cov=src/genjax --cov-report=html --cov-report=term && echo 'Coverage report available at htmlcov/index.html'" # Generate coverage report
doctest = "xdoctest src/genjax --verbose=2" # Run doctests only
doctest-module = "xdoctest src/genjax/{module} --verbose=2" # Run doctests for specific module
test-all = "pytest tests/ -v --cov=src/genjax --cov-report=xml --cov-report=html --cov-report=term && xdoctest src/genjax --verbose=2" # Run tests + doctests
# Benchmarking tasks
benchmark = "pytest tests/ --benchmark-only -v" # Run only benchmark tests
benchmark-all = "pytest tests/ --benchmark-disable-gc --benchmark-sort=mean -v" # Run all tests with benchmarking
benchmark-compare = "pytest tests/ --benchmark-compare=0001 --benchmark-compare-fail=mean:10% -v" # Compare with previous benchmark results
benchmark-save = "pytest tests/ --benchmark-save=current --benchmark-disable-gc -v" # Save benchmark results
benchmark-slowest = "pytest tests/ --durations=20 --benchmark-disable -v" # Show 20 slowest tests without benchmarking
# "gol" (Game of Life) feature.
[tool.pixi.feature.gol.dependencies]
matplotlib = "*"
[tool.pixi.feature.gol.tasks]
# NOTE(review): task named "cmd" running "gol" — confirm intent (see the
# identical pattern in the faircoin feature tasks).
cmd = "gol"
# Game of Life paper workflow (showcase figures)
gol = "python -m examples.gol.main --mode showcase"
gol-paper = "python -m examples.gol.main --mode showcase"
gol-paper-gpu = "python -m examples.gol.main --mode showcase"
# "localization" feature: plotting deps plus ptitprince (raincloud plots).
[tool.pixi.feature.localization.dependencies]
matplotlib = "*"
seaborn = "*"
[tool.pixi.feature.localization.pypi-dependencies]
ptitprince = "*"
[tool.pixi.feature.localization.tasks]
# NOTE(review): task named "cmd" running "localization" — confirm intent.
cmd = "localization"
# Localization paper workflow
localization = "python -m examples.localization.main paper --include-smc-comparison"
localization-paper = "python -m examples.localization.main paper --include-smc-comparison"
# "perfbench" feature: scientific-Python stack plus NumPyro/TFP baselines.
[tool.pixi.feature.perfbench.dependencies]
numpy = "*"
matplotlib = "*"
seaborn = "*"
pandas = "*"
scipy = "*"
[tool.pixi.feature.perfbench.pypi-dependencies]
numpyro = "*"
tensorflow-probability = "*"
# "perfbench-pyro" feature: same stack, plus PyTorch/Pyro. Platform-specific
# targets: pytorch-cuda is only pulled in on linux-64.
[tool.pixi.feature.perfbench-pyro.dependencies]
numpy = "*"
matplotlib = "*"
seaborn = "*"
pandas = "*"
scipy = "*"
[tool.pixi.feature.perfbench-pyro.target.linux-64.dependencies]
pytorch = { version = ">=2.0", channel = "pytorch" }
pytorch-cuda = { version = "12.*", channel = "pytorch" }
torchvision = { version = "*", channel = "pytorch" }
[tool.pixi.feature.perfbench-pyro.target.osx-arm64.dependencies]
pytorch = { version = ">=2.0", channel = "pytorch" }
torchvision = { version = "*", channel = "pytorch" }
[tool.pixi.feature.perfbench-pyro.pypi-dependencies]
pyro-ppl = "*"
# "perfbench-torch" feature: PyTorch baseline without Pyro; pytorch-cuda is
# linux-64 only, mirroring the perfbench-pyro targets.
[tool.pixi.feature.perfbench-torch.dependencies]
numpy = "*"
matplotlib = "*"
seaborn = "*"
pandas = "*"
scipy = "*"
[tool.pixi.feature.perfbench-torch.target.linux-64.dependencies]
pytorch = { version = ">=2.0", channel = "pytorch" }
pytorch-cuda = { version = "12.*", channel = "pytorch" }
[tool.pixi.feature.perfbench-torch.target.osx-arm64.dependencies]
pytorch = { version = ">=2.0", channel = "pytorch" }
# "cuda" feature: requires CUDA 12 on the host and swaps in a CUDA jaxlib
# build on linux-64.
[tool.pixi.feature.cuda.system-requirements]
cuda = "12"
[tool.pixi.feature.cuda.target.linux-64.dependencies]
# Python 3.14 currently lacks dm-tree wheels in our stack, forcing source builds
# that require newer CMake than many systems provide.
python = ">=3.12,<3.14"
# CUDA-enabled JAX for GPU acceleration on linux-64
jaxlib = { version = ">=0.7.2,<0.8", build = "*cuda12*" }
[tool.pixi.feature.cuda.tasks]
# CUDA GPU acceleration tasks
cuda-info = "python -c 'import jax; print(f\"JAX version: {jax.__version__}\"); print(f\"JAX devices: {jax.devices()}\"); print(f\"Default backend: {jax.default_backend()}\")'" # Check CUDA availability
cuda-test = "pixi run test -k 'not slow'" # Run tests with CUDA backend
# Named environments: each combines feature sets with a solve group so that
# CPU ("default") and CUDA variants resolve their dependencies independently.
[tool.pixi.environments]
default = { solve-group = "default" }
format = { features = ["format"], solve-group = "default" }
test = { features = ["test"], solve-group = "default" }
cuda = { features = ["cuda"], solve-group = "cuda" }
faircoin = { features = ["faircoin"], solve-group = "default" }
faircoin-cuda = { features = ["faircoin", "cuda"], solve-group = "cuda" }
curvefit = { features = ["curvefit"], solve-group = "default" }
curvefit-cuda = { features = ["curvefit", "cuda"], solve-group = "cuda" }
gol = { features = ["gol"], solve-group = "default" }
gol-cuda = { features = ["gol", "cuda"], solve-group = "cuda" }
localization = { features = ["localization"], solve-group = "default" }
localization-cuda = { features = ["localization", "cuda"], solve-group = "cuda" }
perfbench = { features = ["perfbench"], solve-group = "perfbench" }
perfbench-cuda = { features = ["perfbench", "cuda"], solve-group = "perfbench-cuda" }
perfbench-pyro = { features = ["perfbench-pyro"], solve-group = "perfbench-pyro" }
perfbench-torch = { features = ["perfbench-torch"], solve-group = "perfbench-torch" }
# nodejs provides npx, needed by the prettier-based format-md/format-all tasks.
# NOTE(review): this table sits far from the other "format" feature tables
# ([tool.pixi.feature.format.tasks] earlier in the file). TOML permits
# out-of-order table definitions but discourages them — consider relocating.
[tool.pixi.feature.format.dependencies]
nodejs = "*"
# PEP 735 dependency groups, paired with the pixi features of the same name.
[dependency-groups]
format = ["ruff>=0.15.2,<0.16", "vulture>=2.14,<3", "pre-commit>=4.0,<5"]
test = [
"pytest>=9.0,<10",
"pytest-cov>=7.0,<8",
"coverage>=7.0,<8",
"xdoctest>=1.1.0,<2",
"pytest-xdist>=3.0,<4",
"pytest-benchmark>=5.0,<6",
]
# coverage.py: measure only src/, excluding test and example code.
[tool.coverage.run]
source = ["src"]
omit = ["*/tests/*", "*/examples/*"]
# Lines matching these regexes are excluded from coverage reports
# (debug-only branches, abstract methods, repr helpers, etc.).
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if self.debug:",
"if settings.DEBUG",
"raise AssertionError",
"raise NotImplementedError",
"if 0:",
"if __name__ == .__main__.:",
"class .*\\bProtocol\\):",
"@(abc\\.)?abstractmethod",
]
# Pytest configuration.
[tool.pytest.ini_options]
# Keep in sync with the "pytest>=9.0,<10" pin in [dependency-groups]; the
# previous value ("8.0") permitted a pytest older than the project installs.
minversion = "9.0"
addopts = [
"-ra", # Show short test summary for all results
"--strict-markers", # Require all markers to be defined
"--strict-config", # Strict configuration parsing
"--cov=src/genjax", # Coverage for source code
"--cov-report=term-missing", # Show missing lines in terminal
"--cov-report=html", # Generate HTML coverage report
"--cov-report=xml", # Generate XML coverage for CI
]
testpaths = ["tests"]
pythonpath = ["."]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
# Registered markers (required because of --strict-markers above).
markers = [
"slow: marks tests as slow (taking >5 seconds)",
"fast: marks tests as fast (taking <1 second)",
"integration: marks tests as integration tests (cross-component)",
"unit: marks tests as unit tests (single component)",
"regression: marks tests as regression tests (bug prevention)",
"adev: marks tests for ADEV gradient estimators",
"smc: marks tests for Sequential Monte Carlo",
"mcmc: marks tests for Markov Chain Monte Carlo",
"vi: marks tests for Variational Inference",
"hmm: marks tests for Hidden Markov Models",
"core: marks tests for core GenJAX functionality",
"pjax: marks tests for PJAX (Probabilistic JAX) functionality",
"distributions: marks tests for probability distributions",
"tfp: marks tests requiring TensorFlow Probability",
"requires_gpu: marks tests that need GPU acceleration",
"benchmark: marks tests that should be benchmarked",
]
# Silence known deprecation noise from JAX/TFP, but fail on UserWarnings.
filterwarnings = [
"ignore::DeprecationWarning:jax.*",
"ignore::DeprecationWarning:tensorflow_probability.*",
"error::UserWarning", # Turn UserWarnings into errors to catch issues
]
# NOTE(review): pytest-benchmark is normally configured via command-line
# flags / pytest ini options; confirm it actually reads a
# [tool.pytest-benchmark] table from pyproject.toml, otherwise these
# settings are silently ignored.
[tool.pytest-benchmark]
# Configuration for pytest-benchmark
min_rounds = 3 # Minimum number of benchmark rounds
max_time = 10.0 # Maximum time per benchmark (seconds)
min_time = 0.01 # Minimum time per round (seconds)
timer = "time.perf_counter" # High-resolution timer
disable_gc = true # Disable garbage collection during benchmarks
sort = "mean" # Sort results by mean time
columns = ["min", "max", "mean", "stddev", "median", "iqr", "outliers", "ops", "rounds"]
histogram = true # Generate histogram data
save = ".benchmarks/benchmarks.json" # Save results to file
save_data = true # Save benchmark data
autosave = true # Automatically save results
# NOTE(review): confirm xdoctest consumes a [tool.xdoctest] table from
# pyproject.toml; the doctest tasks above pass their options on the command
# line, so this table may be unused.
[tool.xdoctest]
# Configure xdoctest for running doctests
modname = "genjax"
command = "list"
verbose = 2
durations = 10
style = "google"
options = "+ELLIPSIS"