Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"\n",
"To begin, we will import a model using [PEtab](https://petab.readthedocs.io). For this demonstration, we will utilize the [Benchmark Collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which provides a diverse set of models. For more information on importing PEtab models, refer to the corresponding [PEtab notebook](https://amici.readthedocs.io/en/latest/petab.html).\n",
"\n",
"In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.petab_import](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem), we will load the PEtab problem. To create a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n"
"In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.importers.petab.PetabImporter](https://amici.readthedocs.io/en/latest/generated/amici.importers.petab.html#amici.importers.petab.PetabImporter), we will load the PEtab problem. To create a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.sim.jax.html#amici.sim.jax.JAXProblem) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n"
]
},
{
Expand Down Expand Up @@ -65,7 +65,7 @@
"source": [
"## Simulation\n",
"\n",
"We can now run efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)."
"We can now run efficient simulation using [amici.sim.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.sim.jax.html#amici.sim.jax.run_simulations)."
]
},
{
Expand All @@ -75,7 +75,7 @@
"metadata": {},
"outputs": [],
"source": [
"from amici.jax import run_simulations\n",
"from amici.sim.jax import run_simulations\n",
"\n",
"# Run simulations and compute the log-likelihood\n",
"llh, results = run_simulations(jax_problem)"
Expand Down Expand Up @@ -249,7 +249,7 @@
"source": [
"The root cause of this error lies in the fact that, to enable autodiff, direct modifications of attributes are not allowed in [equinox](https://docs.kidger.site/equinox/), which AMICI utilizes under the hood. Consequently, attributes of instances like `JAXModel` or `JAXProblem` cannot be updated directly — this is the price we have to pay for autodiff.\n",
"\n",
"However, `JAXProblem` provides a convenient method called [update_parameters](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem.update_parameters). The caveat is that this method creates a new JAXProblem instance instead of modifying the existing one."
"However, `JAXProblem` provides a convenient method called [update_parameters](https://amici.readthedocs.io/en/latest/generated/amici.sim.jax.html#amici.sim.jax.JAXProblem.update_parameters). The caveat is that this method creates a new JAXProblem instance instead of modifying the existing one."
]
},
{
Expand Down Expand Up @@ -386,7 +386,7 @@
"import diffrax\n",
"import jax.numpy as jnp\n",
"import optimistix\n",
"from amici.jax import ReturnValue\n",
"from amici.sim.jax import ReturnValue\n",
"\n",
"# Define the simulation condition\n",
"experiment_condition = \"_petab_experiment_condition___default__\"\n",
Expand Down
2 changes: 1 addition & 1 deletion doc/python_modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ AMICI Python API
amici.importers.petab
amici.importers.petab.v1
amici.importers.utils
amici.jax
amici.sim.jax
amici.sim.sundials
amici.sim.sundials.plotting
amici.sim.sundials.gradient_check
Expand Down
3 changes: 2 additions & 1 deletion python/sdist/amici/_symbolic/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2499,7 +2499,8 @@ def _process_hybridization(self, hybridization: dict) -> None:
orig_obs = tuple([s.get_sym() for s in self._observables])
for net_id, net in hybridization.items():
if net["static"]:
continue # do not integrate into ODEs, handle in amici.jax.petab
# do not integrate into ODEs, handle in amici.sim.jax.petab
continue
inputs = [
comp
for comp in self._components
Expand Down
16 changes: 16 additions & 0 deletions python/sdist/amici/exporters/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Code generation for JAX models for simulation with diffrax solvers.

This module provides an interface to generate and use AMICI models with JAX.
Please note that this module is experimental, the API may substantially change
in the future. Use at your own risk and do not expect backward compatibility.
"""

from .nn import Flatten, cat, generate_equinox, tanhshrink

__all__ = [
"Flatten",
"generate_equinox",
"tanhshrink",
"cat",
]
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# ruff: noqa: F401, F821, F841
from pathlib import Path

import equinox as eqx
import equinox as eqx # noqa: F401
import jax.numpy as jnp
import jax.random as jr
import jaxtyping as jt
from interpax import interp1d
from jax.numpy import inf as oo
from jax.numpy import nan as nan

from amici import _module_from_path
from amici.jax.model import JAXModel, safe_div, safe_log
import jax.random as jr # noqa: F401
import jaxtyping as jt # noqa: F401
from interpax import interp1d # noqa: F401
from jax.numpy import inf as oo # noqa: F401
from jax.numpy import nan as nan # noqa: F401

from amici import _module_from_path # noqa: F401
from amici.sim.jax.model import JAXModel, safe_div, safe_log # noqa: F401

TPL_NET_IMPORTS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import jax.numpy as jnp

from amici import amiciModulePath

from ..exporters.template import apply_template
from amici.exporters.template import apply_template


class Flatten(eqx.Module):
Expand Down Expand Up @@ -185,7 +184,7 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F
layer_map = {
"Dropout1d": "eqx.nn.Dropout",
"Dropout2d": "eqx.nn.Dropout",
"Flatten": "amici.jax.Flatten",
"Flatten": "amici.export.jax.Flatten",
}

# mapping of keyword argument names in sciml yaml format to equinox/custom amici implementations
Expand Down Expand Up @@ -321,9 +320,9 @@ def _process_activation_call(node: "Node") -> str: # noqa: F821
"hardtanh": "jax.nn.hard_tanh",
"hardsigmoid": "jax.nn.hard_sigmoid",
"hardswish": "jax.nn.hard_swish",
"tanhshrink": "amici.jax.tanhshrink",
"tanhshrink": "amici.export.jax.tanhshrink",
"softsign": "jax.nn.soft_sign",
"cat": "amici.jax.cat",
"cat": "amici.export.jax.cat",
}

# Validate hardtanh parameters
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
# ruff: noqa: F401, F821, F841
import equinox as eqx
import jax
import jax.nn
import jax.numpy as jnp
import jax.random as jr

import amici.jax.nn


class TPL_MODEL_ID(eqx.Module):
layers: dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,17 @@

import sympy as sp

from amici import (
amiciModulePath,
)
from amici._symbolic.de_model import DEModel
from amici._symbolic.sympy_utils import (
_monkeypatch_sympy,
)
from amici.exporters.jax.nn import generate_equinox
from amici.exporters.sundials.de_export import is_valid_identifier
from amici.exporters.template import apply_template
from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str
from amici.jax.model import JAXModel
from amici.jax.nn import generate_equinox
from amici.logging import get_logger, log_execution_time, set_log_level
from amici.sim.jax.model import JAXModel

from .jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str

#: python log manager
logger = get_logger(__name__, logging.ERROR)
Expand Down Expand Up @@ -303,7 +301,7 @@ def _generate_jax_code(self) -> None:
}

apply_template(
Path(amiciModulePath) / "jax" / "jax.template.py",
Path(__file__).parent / "jax.template.py",
self.model_path / "__init__.py",
tpl_data,
)
Expand Down
1 change: 1 addition & 0 deletions python/sdist/amici/exporters/sundials/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Code generation for models for simulation with SUNDIALS solvers."""
2 changes: 1 addition & 1 deletion python/sdist/amici/importers/petab/_petab_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from amici import get_model_dir
from amici._symbolic import DEModel, Event
from amici.importers.utils import MeasurementChannel, amici_time_symbol
from amici.jax.petab import JAXProblem
from amici.logging import get_logger
from amici.sim.jax.petab import JAXProblem

from .v1.sbml_import import _add_global_parameter

Expand Down
6 changes: 3 additions & 3 deletions python/sdist/amici/importers/petab/v1/petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def import_petab_problem(
non_estimated_parameters_as_constants=True,
jax=False,
**kwargs,
) -> "amici.sim.sundials.Model | amici.jax.JAXProblem":
) -> "amici.sim.sundials.Model | amici.sim.jax.JAXProblem":
"""
Create an AMICI model for a PEtab problem.

Expand Down Expand Up @@ -75,7 +75,7 @@ def import_petab_problem(

:param jax:
Whether to create a JAX-based problem. If ``True``, returns a
:class:`amici.jax.JAXProblem` instance. If ``False``, returns a
:class:`amici.sim.jax.JAXProblem` instance. If ``False``, returns a
standard AMICI model.

:param kwargs:
Expand Down Expand Up @@ -255,7 +255,7 @@ def import_petab_problem(
)

if jax:
from amici.jax import JAXProblem
from amici.sim.jax import JAXProblem

model = model_module.Model()

Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/importers/pysb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def pysb2jax(
pysb_model_has_obs_and_noise=pysb_model_has_obs_and_noise,
)

from amici.jax.ode_export import ODEExporter
from amici.exporters.jax.ode_export import ODEExporter

exporter = ODEExporter(
ode_model,
Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/importers/sbml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def sbml2jax(
hybridization=hybridization,
)

from amici.jax.ode_export import ODEExporter
from amici.exporters.jax.ode_export import ODEExporter

exporter = ODEExporter(
ode_model,
Expand Down
37 changes: 0 additions & 37 deletions python/sdist/amici/jax/__init__.py

This file was deleted.

Loading
Loading