Skip to content

Make JAX-compatible#8

Open
DavidMStraub wants to merge 3 commits intoflav-io:masterfrom
DavidMStraub:jax
Open

Make JAX-compatible#8
DavidMStraub wants to merge 3 commits intoflav-io:masterfrom
DavidMStraub:jax

Conversation

@DavidMStraub
Copy link
Copy Markdown
Collaborator

This makes ckmutil JAX-compatible.

JAX is an optional dependency and there is no API change when not using JAX (all JAXified functions live under ckmutil.jax).

Code duplication is minimal.

A short note is added to the Readme.

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a JAX-compatible backend for ckmutil by refactoring the CKM and diagonalization logic into backend-agnostic implementations and exposing a new ckmutil.jax module, along with CI and documentation updates.

Changes:

  • Refactor CKM (ckmutil.ckm) and modified SVD (ckmutil.diag.msvd) to call backend-agnostic implementations (ckmutil/_ckm_impl.py, ckmutil/_diag_impl.py).
  • Introduce ckmutil.jax (plus optional dependency ckmutil[jax]) and add JAX-specific tests + CI job.
  • Update README with documentation link and JAX usage guidance.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
README.md Adds docs link and JAX support section with examples.
pyproject.toml Adds jax optional dependency extra.
ckmutil/test_jax.py New tests validating JAX parity with NumPy + grad/vmap behavior.
ckmutil/test_ckm.py Adjusts imports to avoid relying on NumPy symbols from ckmutil.ckm.
ckmutil/jax/init.py New JAX wrapper module exposing JAX-compatible functions via shared implementations.
ckmutil/diag.py Refactors msvd to use the backend-agnostic _msvd implementation.
ckmutil/ckm.py Refactors CKM functions to use backend-agnostic implementations.
ckmutil/_diag_impl.py New backend-agnostic implementation of msvd.
ckmutil/_ckm_impl.py New backend-agnostic implementations of CKM conversion/matrix functions.
.github/workflows/test.yml Adds a dedicated test-jax CI job running JAX tests.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

You can also share your feedback on Copilot code review. Take the survey.

DavidMStraub and others added 2 commits March 12, 2026 12:31
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants