From 08676049056fd069b3b2f2f6220d78a2508ef1ae Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Tue, 27 Jan 2026 17:22:36 +0500 Subject: [PATCH 01/10] CI(gh-2126): transition to uv for dependency management and testing --- .github/workflows/ci.yml | 135 +++++++++++++++++---------------------- 1 file changed, 58 insertions(+), 77 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 291be9f28..e9c825f3f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,54 +13,47 @@ env: jobs: prek: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v6 - name: prek check uses: j178/prek-action@v1 with: extra-args: --all-files --skip ruff --skip ruff-format --skip ty --skip mypy - lint: - runs-on: ubuntu-latest strategy: matrix: python-version: ["3.11", "3.13"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | sudo apt install -y pandoc gsfonts - python -m pip install --upgrade pip - pip install jaxlib - pip install jax - pip install '.[doc,test]' - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -r docs/requirements.txt - pip freeze + uv pip install --upgrade jaxlib jax + uv pip install --upgrade '.[doc,test]' + uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --upgrade -r docs/requirements.txt + uv pip freeze - name: Lint with mypy and ruff run: | - make lint + uv run make lint - name: Build documentation run: | - make docs + uv run make docs - name: Test documentation run: | - make doctest - python -m doctest -v README.md - + uv run make doctest + uv run python -m doctest -v README.md test-modeling: - runs-on: ubuntu-latest needs: [lint, prek] strategy: @@ -68,39 +61,37 @@ jobs: python-version: ["3.11", "3.13"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | sudo apt install -y graphviz - python -m pip install --upgrade pip # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,test]' - pip freeze + uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip + uv pip install --upgrade jaxlib jax + uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --upgrade -e '.[dev,test]' + uv pip freeze - name: Test with pytest run: | - CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/ + CI=1 uv run pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/ - name: Test x64 run: | - JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k "powerLaw or Dagum" + JAX_ENABLE_X64=1 uv run pytest -vs test/test_distributions.py -k "powerLaw or Dagum" - name: Test tracer leak if: matrix.python-version == '3.13' env: JAX_CHECK_TRACER_LEAKS: 1 run: | - pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit - pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke - pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run - pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths - pytest -vs test/test_distributions.py::test_mean_var -k Gompertz - + uv run pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit + uv run pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke + uv run pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run + uv run pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths + uv run pytest -vs test/test_distributions.py::test_mean_var -k Gompertz - name: Coveralls if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' uses: coverallsapp/github-action@v2 @@ -109,9 +100,7 @@ jobs: parallel: true flag-name: test-modeling - test-inference: - runs-on: ubuntu-latest needs: [lint, prek] strategy: @@ -119,41 +108,39 @@ jobs: python-version: ["3.11", "3.13"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,test]' - pip freeze + uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip + uv pip install --upgrade jaxlib jax + uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --upgrade -e '.[dev,test]' + uv pip freeze - name: Test with pytest run: | - pytest -vs --durations=20 test/infer/test_mcmc.py - pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py - pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py + uv run pytest -vs --durations=20 test/infer/test_mcmc.py + uv run pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py + uv run pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py - name: Test x64 run: | - JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64 + JAX_ENABLE_X64=1 uv run pytest -vs test/infer/test_mcmc.py -k x64 - name: Test chains run: | - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap" - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain" - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain" + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap" + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/contrib/test_tfp.py -k "chain" + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/contrib/stochastic_support/test_dcc.py + XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs test/infer/test_hmc_gibbs.py -k "chain" - name: Test custom prng run: | - JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py + JAX_ENABLE_CUSTOM_PRNG=1 uv run pytest -vs test/infer/test_mcmc.py - name: Test nested sampling run: | - JAX_ENABLE_X64=1 pytest -vs test/contrib/test_nested_sampling.py + JAX_ENABLE_X64=1 uv run pytest -vs test/contrib/test_nested_sampling.py - name: Coveralls if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' uses: coverallsapp/github-action@v2 @@ -162,9 +149,7 @@ jobs: parallel: true flag-name: test-inference - examples: - runs-on: ubuntu-latest needs: [lint, prek] strategy: @@ -172,22 +157,21 @@ jobs: python-version: ["3.13"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - uses: actions/checkout@v6 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: + enable-cache: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,examples,test]' - pip freeze + uv pip install --upgrade jaxlib jax + uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --upgrade -e '.[dev,examples,test]' + uv pip freeze - name: Test with pytest run: | - CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs -k test_example + CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs -k test_example - name: Coveralls if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' uses: coverallsapp/github-action@v2 @@ -196,9 +180,7 @@ jobs: parallel: true flag-name: examples - finish: - needs: [test-modeling, test-inference, examples] runs-on: ubuntu-latest steps: @@ -208,4 +190,3 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} parallel-finished: true carryforward: "test-modeling,test-inference,examples" - From 5453f405f09c2292e950e35d62c0198c97966545 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Tue, 27 Jan 2026 17:38:44 +0500 Subject: [PATCH 02/10] fix: install packages in system --- .github/workflows/ci.yml | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e9c825f3f..e6dc21ffb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,10 +37,10 @@ jobs: - name: Install dependencies run: | sudo apt install -y pandoc gsfonts - uv pip install --upgrade jaxlib jax - uv pip install --upgrade '.[doc,test]' - uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip - uv pip install --upgrade -r docs/requirements.txt + uv pip install --system --upgrade jaxlib jax + uv pip install --system --upgrade '.[doc,test]' + uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --system --upgrade -r docs/requirements.txt uv pip freeze - name: Lint with mypy and ruff run: | @@ -71,10 +71,10 @@ jobs: run: | sudo apt install -y graphviz # Keep track of pyro-api master branch - uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip - uv pip install --upgrade jaxlib jax - uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip - uv pip install --upgrade -e '.[dev,test]' + uv pip install --system --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip + uv pip install --system --upgrade jaxlib jax + uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --system --upgrade -e '.[dev,test]' uv pip freeze - name: Test with pytest run: | @@ -116,10 +116,10 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip - uv pip install --upgrade jaxlib jax - uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip - uv pip install --upgrade -e '.[dev,test]' + uv pip install --system --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip + uv pip install --system --upgrade jaxlib jax + uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --system --upgrade -e '.[dev,test]' uv pip freeze - name: Test with pytest run: | @@ -165,9 +165,9 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - uv pip install --upgrade jaxlib jax - uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip - uv pip install --upgrade -e '.[dev,examples,test]' + uv pip install --system --upgrade jaxlib jax + uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --system --upgrade -e '.[dev,examples,test]' uv pip freeze - name: Test with pytest run: | From 551d801448bcbee7fd2c6d2030c30b35d9880ae5 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Tue, 27 Jan 2026 17:45:39 +0500 Subject: [PATCH 03/10] fix: update path --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e6dc21ffb..8ff31b38c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,6 +33,7 @@ jobs: uses: astral-sh/setup-uv@v7 with: enable-cache: true + update-path: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | @@ -65,6 +66,7 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v7 with: + update-path: true enable-cache: true python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -113,6 +115,7 @@ jobs: uses: astral-sh/setup-uv@v7 with: enable-cache: true + update-path: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | @@ -162,6 +165,7 @@ jobs: uses: astral-sh/setup-uv@v7 with: enable-cache: true + update-path: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | From 2fdb2b70f2458146887700f0c164164aae7b30b1 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Tue, 27 Jan 2026 17:48:13 +0500 Subject: [PATCH 04/10] fix: specify uv python version in enviornment variable --- .github/workflows/ci.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8ff31b38c..04588e191 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,8 @@ jobs: strategy: matrix: python-version: ["3.11", "3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v6 @@ -60,6 +62,8 @@ jobs: strategy: matrix: python-version: ["3.11", "3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v6 @@ -108,6 +112,8 @@ jobs: strategy: matrix: python-version: ["3.11", "3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v6 @@ -158,6 +164,8 @@ jobs: strategy: matrix: python-version: ["3.13"] + env: + UV_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v6 From c3ce49fddde380b18fc9af270b70d0c3fe01fc37 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Thu, 12 Feb 2026 00:59:48 +0500 Subject: [PATCH 05/10] chore: `setup.py` -> `pyproject.toml` Co-authored-by: nstarman --- pyproject.toml | 97 +++++++++++++++++++++++++++++++++++++++++++ setup.py | 109 ------------------------------------------------- 2 files changed, 97 insertions(+), 109 deletions(-) delete mode 100644 setup.py diff --git a/pyproject.toml b/pyproject.toml index 933434e59..008c11921 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,100 @@ +[project] +name = "numpyro" +authors = [{ name = "Uber AI Labs", email = "fehiepsi@gmail.com" }] +dynamic = ["version"] +description = "Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU." +readme = "README.md" +license = { file = "LICENSE" } +requires-python = ">=3.9" +classifiers = [ + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS :: MacOS X", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", +] +keywords = ["probabilistic", "machine learning", "bayesian", "statistics"] +dependencies = [ + "jax>=0.7.0", + "jaxlib>=0.7.0", + "multipledispatch", + "numpy", + "tqdm", +] + +[tool.setuptools.dynamic] +version = { attr = "numpyro.version.__version__" } + +[project.optional-dependencies] +cpu = ["jax[cpu]>=0.7.0"] +cuda12 = ["jax[cuda12]>=0.7.0"] +cuda13 = ["jax[cuda13]>=0.7.0"] +tpu = ["jax[tpu]>=0.7.0"] +dev = [ + "dm-haiku>=0.0.14", + "equinox", + "flax", + "funsor>=0.4.1", + "graphviz", + "jaxns>=2.6.3,<=2.6.9", + "matplotlib", + "optax>=0.0.6", + "pylab-sdk", # jaxns dependency + "pytest-cov", + "pyyaml", # flax dependency + "requests", # pylab dependency + "tfp-nightly", +] +test = [ + "importlib-metadata<5.0", + "mypy>=1.13", + "pyro-api>=0.1.1", + "pytest>=4.1", + "ruff>=0.1.8", + "scikit-learn", + "scipy>=1.9", + "ty>=0.0.4", +] +doc = [ + "ipython", # sphinx needs this to render codes + "nbsphinx>=0.8.9", + "readthedocs-sphinx-search>=0.3.2", + "sphinx_rtd_theme", + "sphinx-gallery", + "sphinx>=5", +] +examples = [ + "arviz", + "jupyter", + "matplotlib", + "pandas", + "scikit-learn", + "seaborn", + "wordcloud", +] + +[project.urls] +Changelog = "https://github.com/pyro-ppl/numpyro/blob/main/CHANGELOG.md" +Discussion = "https://github.com/pyro-ppl/numpyro/discussions" +Homepage = "https://github.com/pyro-ppl/numpyro" +Issues = "https://github.com/pyro-ppl/numpyro/issues" + +[build-system] +requires = ["setuptools>=61", "wheel"] +build-backend = "setuptools.build_meta" + +# NOTE: this can be simplified using src-layout +[tool.setuptools.packages.find] +include = ["numpyro*"] + +[tool.setuptools] +include-package-data = true + [tool.ruff] # Exclude a variety of commonly ignored directories. exclude = [ diff --git a/setup.py b/setup.py deleted file mode 100644 index f0bef9f5a..000000000 --- a/setup.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import absolute_import, division, print_function - -import os -import sys - -from setuptools import find_packages, setup - -PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) -_jax_version_constraints = ">=0.7.0" -_jaxlib_version_constraints = ">=0.7.0" - -# Find version -for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")): - if line.startswith("__version__ = "): - version = line.strip().split()[2][1:-1] - -# READ README.md for long description on PyPi. -try: - long_description = open("README.md", encoding="utf-8").read() -except Exception as e: - sys.stderr.write("Failed to read README.md:\n {}\n".format(e)) - sys.stderr.flush() - long_description = "" - -setup( - name="numpyro", - version=version, - description="Pyro PPL on NumPy", - packages=find_packages(include=["numpyro", "numpyro.*"]), - url="https://github.com/pyro-ppl/numpyro", - author="Uber AI Labs", - install_requires=[ - f"jax{_jax_version_constraints}", - f"jaxlib{_jaxlib_version_constraints}", - "multipledispatch", - "numpy", - "tqdm", - ], - extras_require={ - "doc": [ - "ipython", # sphinx needs this to render codes - "nbsphinx>=0.8.9", - "readthedocs-sphinx-search>=0.3.2", - "sphinx>=5", - "sphinx_rtd_theme", - "sphinx-gallery", - ], - "test": [ - "importlib-metadata<5.0", - "ruff>=0.1.8", - "mypy>=1.13", - "pytest>=4.1", - "pyro-api>=0.1.1", - "scikit-learn", - "scipy>=1.9", - "ty>=0.0.4", - ], - "dev": [ - "dm-haiku>=0.0.14", - "equinox", - "flax", - "funsor>=0.4.1", - "graphviz", - "jaxns>=2.6.3,<=2.6.9", - "matplotlib", - "optax>=0.0.6", - "pylab-sdk", # jaxns dependency - "pytest-cov", - "pyyaml", # flax dependency - "requests", # pylab dependency - "tfp-nightly", - ], - "examples": [ - "arviz", - "jupyter", - "matplotlib", - "pandas", - "seaborn", - "scikit-learn", - "wordcloud", - ], - "cpu": f"jax[cpu]{_jax_version_constraints}", - # TPU and CUDA installations, currently require to add package repository URL, i.e., - # pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_releases.html - "tpu": f"jax[tpu]{_jax_version_constraints}", - "cuda12": f"jax[cuda12]{_jax_version_constraints}", - "cuda13": f"jax[cuda13]{_jax_version_constraints}", - }, - python_requires=">=3.9", - long_description=long_description, - long_description_content_type="text/markdown", - keywords="probabilistic machine learning bayesian statistics", - license="Apache License 2.0", - classifiers=[ - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Operating System :: POSIX :: Linux", - "Operating System :: MacOS :: MacOS X", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Programming Language :: Python :: 3.14", - ], -) From fde3b4115c36e7be41db17dbf077c93b5c620514 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Sun, 15 Feb 2026 12:58:51 +0500 Subject: [PATCH 06/10] Revert "chore: `setup.py` -> `pyproject.toml`" This reverts commit c3ce49fddde380b18fc9af270b70d0c3fe01fc37. --- pyproject.toml | 97 ------------------------------------------- setup.py | 109 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 97 deletions(-) create mode 100644 setup.py diff --git a/pyproject.toml b/pyproject.toml index 008c11921..933434e59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,100 +1,3 @@ -[project] -name = "numpyro" -authors = [{ name = "Uber AI Labs", email = "fehiepsi@gmail.com" }] -dynamic = ["version"] -description = "Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU." -readme = "README.md" -license = { file = "LICENSE" } -requires-python = ">=3.9" -classifiers = [ - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Operating System :: POSIX :: Linux", - "Operating System :: MacOS :: MacOS X", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Programming Language :: Python :: 3.14", -] -keywords = ["probabilistic", "machine learning", "bayesian", "statistics"] -dependencies = [ - "jax>=0.7.0", - "jaxlib>=0.7.0", - "multipledispatch", - "numpy", - "tqdm", -] - -[tool.setuptools.dynamic] -version = { attr = "numpyro.version.__version__" } - -[project.optional-dependencies] -cpu = ["jax[cpu]>=0.7.0"] -cuda12 = ["jax[cuda12]>=0.7.0"] -cuda13 = ["jax[cuda13]>=0.7.0"] -tpu = ["jax[tpu]>=0.7.0"] -dev = [ - "dm-haiku>=0.0.14", - "equinox", - "flax", - "funsor>=0.4.1", - "graphviz", - "jaxns>=2.6.3,<=2.6.9", - "matplotlib", - "optax>=0.0.6", - "pylab-sdk", # jaxns dependency - "pytest-cov", - "pyyaml", # flax dependency - "requests", # pylab dependency - "tfp-nightly", -] -test = [ - "importlib-metadata<5.0", - "mypy>=1.13", - "pyro-api>=0.1.1", - "pytest>=4.1", - "ruff>=0.1.8", - "scikit-learn", - "scipy>=1.9", - "ty>=0.0.4", -] -doc = [ - "ipython", # sphinx needs this to render codes - "nbsphinx>=0.8.9", - "readthedocs-sphinx-search>=0.3.2", - "sphinx_rtd_theme", - "sphinx-gallery", - "sphinx>=5", -] -examples = [ - "arviz", - "jupyter", - "matplotlib", - "pandas", - "scikit-learn", - "seaborn", - "wordcloud", -] - -[project.urls] -Changelog = "https://github.com/pyro-ppl/numpyro/blob/main/CHANGELOG.md" -Discussion = "https://github.com/pyro-ppl/numpyro/discussions" -Homepage = "https://github.com/pyro-ppl/numpyro" -Issues = "https://github.com/pyro-ppl/numpyro/issues" - -[build-system] -requires = ["setuptools>=61", "wheel"] -build-backend = "setuptools.build_meta" - -# NOTE: this can be simplified using src-layout -[tool.setuptools.packages.find] -include = ["numpyro*"] - -[tool.setuptools] -include-package-data = true - [tool.ruff] # Exclude a variety of commonly ignored directories. exclude = [ diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..f0bef9f5a --- /dev/null +++ b/setup.py @@ -0,0 +1,109 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import absolute_import, division, print_function + +import os +import sys + +from setuptools import find_packages, setup + +PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) +_jax_version_constraints = ">=0.7.0" +_jaxlib_version_constraints = ">=0.7.0" + +# Find version +for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")): + if line.startswith("__version__ = "): + version = line.strip().split()[2][1:-1] + +# READ README.md for long description on PyPi. +try: + long_description = open("README.md", encoding="utf-8").read() +except Exception as e: + sys.stderr.write("Failed to read README.md:\n {}\n".format(e)) + sys.stderr.flush() + long_description = "" + +setup( + name="numpyro", + version=version, + description="Pyro PPL on NumPy", + packages=find_packages(include=["numpyro", "numpyro.*"]), + url="https://github.com/pyro-ppl/numpyro", + author="Uber AI Labs", + install_requires=[ + f"jax{_jax_version_constraints}", + f"jaxlib{_jaxlib_version_constraints}", + "multipledispatch", + "numpy", + "tqdm", + ], + extras_require={ + "doc": [ + "ipython", # sphinx needs this to render codes + "nbsphinx>=0.8.9", + "readthedocs-sphinx-search>=0.3.2", + "sphinx>=5", + "sphinx_rtd_theme", + "sphinx-gallery", + ], + "test": [ + "importlib-metadata<5.0", + "ruff>=0.1.8", + "mypy>=1.13", + "pytest>=4.1", + "pyro-api>=0.1.1", + "scikit-learn", + "scipy>=1.9", + "ty>=0.0.4", + ], + "dev": [ + "dm-haiku>=0.0.14", + "equinox", + "flax", + "funsor>=0.4.1", + "graphviz", + "jaxns>=2.6.3,<=2.6.9", + "matplotlib", + "optax>=0.0.6", + "pylab-sdk", # jaxns dependency + "pytest-cov", + "pyyaml", # flax dependency + "requests", # pylab dependency + "tfp-nightly", + ], + "examples": [ + "arviz", + "jupyter", + "matplotlib", + "pandas", + "seaborn", + "scikit-learn", + "wordcloud", + ], + "cpu": f"jax[cpu]{_jax_version_constraints}", + # TPU and CUDA installations, currently require to add package repository URL, i.e., + # pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_releases.html + "tpu": f"jax[tpu]{_jax_version_constraints}", + "cuda12": f"jax[cuda12]{_jax_version_constraints}", + "cuda13": f"jax[cuda13]{_jax_version_constraints}", + }, + python_requires=">=3.9", + long_description=long_description, + long_description_content_type="text/markdown", + keywords="probabilistic machine learning bayesian statistics", + license="Apache License 2.0", + classifiers=[ + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS :: MacOS X", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + ], +) From 5f0d34a7dcc5a2d24fec5f84b9d02b47bdc343a4 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 16 Feb 2026 11:35:00 +0500 Subject: [PATCH 07/10] chore: remove `UV_PYTHON` from env variable --- .github/workflows/ci.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 04588e191..eff544d40 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,8 +26,6 @@ jobs: strategy: matrix: python-version: ["3.11", "3.13"] - env: - UV_PYTHON: ${{ matrix.python-version }} steps: - uses: actions/checkout@v6 @@ -35,7 +33,6 @@ jobs: uses: astral-sh/setup-uv@v7 with: enable-cache: true - update-path: true python-version: ${{ matrix.python-version }} - name: Install dependencies run: | From d8c3c89f635582ba1d4cc84c966e9fcf0b84ef70 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 16 Feb 2026 11:44:52 +0500 Subject: [PATCH 08/10] chore: explicit python installation --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index eff544d40..5613c3eb0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -120,6 +120,8 @@ jobs: enable-cache: true update-path: true python-version: ${{ matrix.python-version }} + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} - name: Install dependencies run: | uv pip install --system --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip From 8c1b1b07cc2ae54c1bf025910f20b8e484f65f8a Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 16 Feb 2026 11:46:06 +0500 Subject: [PATCH 09/10] fix: remove `--system` flag --- .github/workflows/ci.yml | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5613c3eb0..4a7d0f885 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,10 +37,10 @@ jobs: - name: Install dependencies run: | sudo apt install -y pandoc gsfonts - uv pip install --system --upgrade jaxlib jax - uv pip install --system --upgrade '.[doc,test]' - uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip - uv pip install --system --upgrade -r docs/requirements.txt + uv pip install --upgrade jaxlib jax + uv pip install --upgrade '.[doc,test]' + uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --upgrade -r docs/requirements.txt uv pip freeze - name: Lint with mypy and ruff run: | @@ -74,10 +74,10 @@ jobs: run: | sudo apt install -y graphviz # Keep track of pyro-api master branch - uv pip install --system --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip - uv pip install --system --upgrade jaxlib jax - uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip - uv pip install --system --upgrade -e '.[dev,test]' + uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip + uv pip install --upgrade jaxlib jax + uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --upgrade -e '.[dev,test]' uv pip freeze - name: Test with pytest run: | @@ -124,10 +124,10 @@ jobs: run: uv python install ${{ matrix.python-version }} - name: Install dependencies run: | - uv pip install --system --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip - uv pip install --system --upgrade jaxlib jax - uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip - uv pip install --system --upgrade -e '.[dev,test]' + uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip + uv pip install --upgrade jaxlib jax + uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --upgrade -e '.[dev,test]' uv pip freeze - name: Test with pytest run: | @@ -176,9 +176,9 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - uv pip install --system --upgrade jaxlib jax - uv pip install --system --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip - uv pip install --system --upgrade -e '.[dev,examples,test]' + uv pip install --upgrade jaxlib jax + uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip + uv pip install --upgrade -e '.[dev,examples,test]' uv pip freeze - name: Test with pytest run: | From 3d0ab04539ea32cd2128c5cdce26242cd39e49e1 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Wed, 25 Feb 2026 22:51:27 +0500 Subject: [PATCH 10/10] fix: avoid installing pyro-api https://github.com/pyro-ppl/pyro-api/pull/26 --- .github/workflows/ci.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4a7d0f885..1d13e4881 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -74,7 +74,8 @@ jobs: run: | sudo apt install -y graphviz # Keep track of pyro-api master branch - uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip + # See: https://github.com/pyro-ppl/pyro-api/pull/26 + # uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip uv pip install --upgrade jaxlib jax uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip uv pip install --upgrade -e '.[dev,test]' @@ -124,7 +125,8 @@ jobs: run: uv python install ${{ matrix.python-version }} - name: Install dependencies run: | - uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip + # See: https://github.com/pyro-ppl/pyro-api/pull/26 + # uv pip install --upgrade https://github.com/pyro-ppl/pyro-api/archive/master.zip uv pip install --upgrade jaxlib jax uv pip install --upgrade https://github.com/pyro-ppl/funsor/archive/master.zip uv pip install --upgrade -e '.[dev,test]'