From e5de18b0823780a5b30fd2799a3b2e33e2ebef57 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 25 Feb 2026 09:49:57 +0100 Subject: [PATCH] Move amici.jax to amici.exporters.jax and amici.sim.jax Related to #3041. --- .../example_jax_petab/ExampleJaxPEtab.ipynb | 10 +- doc/python_modules.rst | 2 +- python/sdist/amici/_symbolic/de_model.py | 3 +- python/sdist/amici/exporters/jax/__init__.py | 16 ++ .../amici/{ => exporters}/jax/jax.template.py | 18 +- .../{ => exporters}/jax/jaxcodeprinter.py | 0 python/sdist/amici/{ => exporters}/jax/nn.py | 9 +- .../amici/{ => exporters}/jax/nn.template.py | 5 - .../amici/{ => exporters}/jax/ode_export.py | 12 +- .../amici/exporters/sundials/__init__.py | 1 + .../amici/importers/petab/_petab_importer.py | 2 +- .../amici/importers/petab/v1/petab_import.py | 6 +- python/sdist/amici/importers/pysb/__init__.py | 2 +- python/sdist/amici/importers/sbml/__init__.py | 2 +- python/sdist/amici/jax/__init__.py | 37 ---- python/sdist/amici/sim/jax/__init__.py | 161 +++--------------- .../sdist/amici/{ => sim}/jax/_simulation.py | 0 python/sdist/amici/{ => sim}/jax/model.py | 2 +- python/sdist/amici/{ => sim}/jax/petab.py | 141 ++++++++++++++- python/tests/test_jax.py | 11 +- python/tests/test_sciml.py | 4 +- .../test_petab_benchmark_jax.py | 2 +- tests/petab_test_suite/test_petab_suite.py | 2 +- tests/petab_test_suite/test_petab_v2_suite.py | 4 +- tests/sbml/testSBMLSuite.py | 4 +- tests/sbml/testSBMLSuiteJax.py | 4 +- tests/sciml/test_sciml.py | 7 +- 27 files changed, 224 insertions(+), 243 deletions(-) create mode 100644 python/sdist/amici/exporters/jax/__init__.py rename python/sdist/amici/{ => exporters}/jax/jax.template.py (89%) rename python/sdist/amici/{ => exporters}/jax/jaxcodeprinter.py (100%) rename python/sdist/amici/{ => exporters}/jax/nn.py (98%) rename python/sdist/amici/{ => exporters}/jax/nn.template.py (87%) rename python/sdist/amici/{ => exporters}/jax/ode_export.py (97%) delete mode 100644 python/sdist/amici/jax/__init__.py rename python/sdist/amici/{ => sim}/jax/_simulation.py (100%) rename python/sdist/amici/{ => sim}/jax/model.py (99%) rename python/sdist/amici/{ => sim}/jax/petab.py (93%) diff --git a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb index a12b4d92fa..5ade5cf528 100644 --- a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -21,7 +21,7 @@ "\n", "To begin, we will import a model using [PEtab](https://petab.readthedocs.io). For this demonstration, we will utilize the [Benchmark Collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which provides a diverse set of models. For more information on importing PEtab models, refer to the corresponding [PEtab notebook](https://amici.readthedocs.io/en/latest/petab.html).\n", "\n", - "In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.petab_import](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem), we will load the PEtab problem. To create a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n" + "In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.importers.petab.PetabImporter](https://amici.readthedocs.io/en/latest/generated/amici.importers.petab.html#amici.importers.petab.PetabImporter), we will load the PEtab problem. To create a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.sim.jax.html#amici.sim.jax.JAXProblem) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n" ] }, { @@ -65,7 +65,7 @@ "source": [ "## Simulation\n", "\n", - "We can now run efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)." + "We can now run efficient simulation using [amici.sim.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.sim.jax.html#amici.sim.jax.run_simulations)." ] }, { @@ -75,7 +75,7 @@ "metadata": {}, "outputs": [], "source": [ - "from amici.jax import run_simulations\n", + "from amici.sim.jax import run_simulations\n", "\n", "# Run simulations and compute the log-likelihood\n", "llh, results = run_simulations(jax_problem)" @@ -249,7 +249,7 @@ "source": [ "The root cause of this error lies in the fact that, to enable autodiff, direct modifications of attributes are not allowed in [equinox](https://docs.kidger.site/equinox/), which AMICI utilizes under the hood. Consequently, attributes of instances like `JAXModel` or `JAXProblem` cannot be updated directly — this is the price we have to pay for autodiff.\n", "\n", - "However, `JAXProblem` provides a convenient method called [update_parameters](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem.update_parameters). The caveat is that this method creates a new JAXProblem instance instead of modifying the existing one." + "However, `JAXProblem` provides a convenient method called [update_parameters](https://amici.readthedocs.io/en/latest/generated/amici.sim.jax.html#amici.sim.jax.JAXProblem.update_parameters). The caveat is that this method creates a new JAXProblem instance instead of modifying the existing one." ] }, { @@ -386,7 +386,7 @@ "import diffrax\n", "import jax.numpy as jnp\n", "import optimistix\n", - "from amici.jax import ReturnValue\n", + "from amici.sim.jax import ReturnValue\n", "\n", "# Define the simulation condition\n", "experiment_condition = \"_petab_experiment_condition___default__\"\n", diff --git a/doc/python_modules.rst b/doc/python_modules.rst index 7eeb000996..92594870b7 100644 --- a/doc/python_modules.rst +++ b/doc/python_modules.rst @@ -15,7 +15,7 @@ AMICI Python API amici.importers.petab amici.importers.petab.v1 amici.importers.utils - amici.jax + amici.sim.jax amici.sim.sundials amici.sim.sundials.plotting amici.sim.sundials.gradient_check diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index d6e96c708c..7217bc8be9 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -2499,7 +2499,8 @@ def _process_hybridization(self, hybridization: dict) -> None: orig_obs = tuple([s.get_sym() for s in self._observables]) for net_id, net in hybridization.items(): if net["static"]: - continue # do not integrate into ODEs, handle in amici.jax.petab + # do not integrate into ODEs, handle in amici.sim.jax.petab + continue inputs = [ comp for comp in self._components diff --git a/python/sdist/amici/exporters/jax/__init__.py b/python/sdist/amici/exporters/jax/__init__.py new file mode 100644 index 0000000000..3b9cd14347 --- /dev/null +++ b/python/sdist/amici/exporters/jax/__init__.py @@ -0,0 +1,16 @@ +""" +Code generation for JAX models for simulation with diffrax solvers. + +This module provides an interface to generate and use AMICI models with JAX. +Please note that this module is experimental, the API may substantially change +in the future. Use at your own risk and do not expect backward compatibility. +""" + +from .nn import Flatten, cat, generate_equinox, tanhshrink + +__all__ = [ + "Flatten", + "generate_equinox", + "tanhshrink", + "cat", +] diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/exporters/jax/jax.template.py similarity index 89% rename from python/sdist/amici/jax/jax.template.py rename to python/sdist/amici/exporters/jax/jax.template.py index 063adc5045..37b6339d5a 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/exporters/jax/jax.template.py @@ -1,16 +1,16 @@ # ruff: noqa: F401, F821, F841 from pathlib import Path -import equinox as eqx +import equinox as eqx # noqa: F401 import jax.numpy as jnp -import jax.random as jr -import jaxtyping as jt -from interpax import interp1d -from jax.numpy import inf as oo -from jax.numpy import nan as nan - -from amici import _module_from_path -from amici.jax.model import JAXModel, safe_div, safe_log +import jax.random as jr # noqa: F401 +import jaxtyping as jt # noqa: F401 +from interpax import interp1d # noqa: F401 +from jax.numpy import inf as oo # noqa: F401 +from jax.numpy import nan as nan # noqa: F401 + +from amici import _module_from_path # noqa: F401 +from amici.sim.jax.model import JAXModel, safe_div, safe_log # noqa: F401 TPL_NET_IMPORTS diff --git a/python/sdist/amici/jax/jaxcodeprinter.py b/python/sdist/amici/exporters/jax/jaxcodeprinter.py similarity index 100% rename from python/sdist/amici/jax/jaxcodeprinter.py rename to python/sdist/amici/exporters/jax/jaxcodeprinter.py diff --git a/python/sdist/amici/jax/nn.py b/python/sdist/amici/exporters/jax/nn.py similarity index 98% rename from python/sdist/amici/jax/nn.py rename to python/sdist/amici/exporters/jax/nn.py index 28bda1249a..37f51fd587 100644 --- a/python/sdist/amici/jax/nn.py +++ b/python/sdist/amici/exporters/jax/nn.py @@ -4,8 +4,7 @@ import jax.numpy as jnp from amici import amiciModulePath - -from ..exporters.template import apply_template +from amici.exporters.template import apply_template class Flatten(eqx.Module): @@ -185,7 +184,7 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F layer_map = { "Dropout1d": "eqx.nn.Dropout", "Dropout2d": "eqx.nn.Dropout", - "Flatten": "amici.jax.Flatten", + "Flatten": "amici.export.jax.Flatten", } # mapping of keyword argument names in sciml yaml format to equinox/custom amici implementations @@ -321,9 +320,9 @@ def _process_activation_call(node: "Node") -> str: # noqa: F821 "hardtanh": "jax.nn.hard_tanh", "hardsigmoid": "jax.nn.hard_sigmoid", "hardswish": "jax.nn.hard_swish", - "tanhshrink": "amici.jax.tanhshrink", + "tanhshrink": "amici.export.jax.tanhshrink", "softsign": "jax.nn.soft_sign", - "cat": "amici.jax.cat", + "cat": "amici.export.jax.cat", } # Validate hardtanh parameters diff --git a/python/sdist/amici/jax/nn.template.py b/python/sdist/amici/exporters/jax/nn.template.py similarity index 87% rename from python/sdist/amici/jax/nn.template.py rename to python/sdist/amici/exporters/jax/nn.template.py index 6b20a39f1b..208eefe72d 100644 --- a/python/sdist/amici/jax/nn.template.py +++ b/python/sdist/amici/exporters/jax/nn.template.py @@ -1,12 +1,7 @@ # ruff: noqa: F401, F821, F841 import equinox as eqx -import jax -import jax.nn -import jax.numpy as jnp import jax.random as jr -import amici.jax.nn - class TPL_MODEL_ID(eqx.Module): layers: dict diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/exporters/jax/ode_export.py similarity index 97% rename from python/sdist/amici/jax/ode_export.py rename to python/sdist/amici/exporters/jax/ode_export.py index 7d69fa3622..127e772f29 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/exporters/jax/ode_export.py @@ -18,19 +18,17 @@ import sympy as sp -from amici import ( - amiciModulePath, -) from amici._symbolic.de_model import DEModel from amici._symbolic.sympy_utils import ( _monkeypatch_sympy, ) +from amici.exporters.jax.nn import generate_equinox from amici.exporters.sundials.de_export import is_valid_identifier from amici.exporters.template import apply_template -from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str -from amici.jax.model import JAXModel -from amici.jax.nn import generate_equinox from amici.logging import get_logger, log_execution_time, set_log_level +from amici.sim.jax.model import JAXModel + +from .jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str #: python log manager logger = get_logger(__name__, logging.ERROR) @@ -303,7 +301,7 @@ def _generate_jax_code(self) -> None: } apply_template( - Path(amiciModulePath) / "jax" / "jax.template.py", + Path(__file__).parent / "jax.template.py", self.model_path / "__init__.py", tpl_data, ) diff --git a/python/sdist/amici/exporters/sundials/__init__.py b/python/sdist/amici/exporters/sundials/__init__.py index e69de29bb2..2211b15a85 100644 --- a/python/sdist/amici/exporters/sundials/__init__.py +++ b/python/sdist/amici/exporters/sundials/__init__.py @@ -0,0 +1 @@ +"""Code generation for models for simulation with SUNDIALS solvers.""" diff --git a/python/sdist/amici/importers/petab/_petab_importer.py b/python/sdist/amici/importers/petab/_petab_importer.py index bcb8633c51..9f1c056e80 100644 --- a/python/sdist/amici/importers/petab/_petab_importer.py +++ b/python/sdist/amici/importers/petab/_petab_importer.py @@ -23,8 +23,8 @@ from amici import get_model_dir from amici._symbolic import DEModel, Event from amici.importers.utils import MeasurementChannel, amici_time_symbol -from amici.jax.petab import JAXProblem from amici.logging import get_logger +from amici.sim.jax.petab import JAXProblem from .v1.sbml_import import _add_global_parameter diff --git a/python/sdist/amici/importers/petab/v1/petab_import.py b/python/sdist/amici/importers/petab/v1/petab_import.py index 137b1b456d..d15d23a50f 100644 --- a/python/sdist/amici/importers/petab/v1/petab_import.py +++ b/python/sdist/amici/importers/petab/v1/petab_import.py @@ -47,7 +47,7 @@ def import_petab_problem( non_estimated_parameters_as_constants=True, jax=False, **kwargs, -) -> "amici.sim.sundials.Model | amici.jax.JAXProblem": +) -> "amici.sim.sundials.Model | amici.sim.jax.JAXProblem": """ Create an AMICI model for a PEtab problem. @@ -75,7 +75,7 @@ def import_petab_problem( :param jax: Whether to create a JAX-based problem. If ``True``, returns a - :class:`amici.jax.JAXProblem` instance. If ``False``, returns a + :class:`amici.sim.jax.JAXProblem` instance. If ``False``, returns a standard AMICI model. :param kwargs: @@ -255,7 +255,7 @@ def import_petab_problem( ) if jax: - from amici.jax import JAXProblem + from amici.sim.jax import JAXProblem model = model_module.Model() diff --git a/python/sdist/amici/importers/pysb/__init__.py b/python/sdist/amici/importers/pysb/__init__.py index 38c5ac2d9b..2e3665f086 100644 --- a/python/sdist/amici/importers/pysb/__init__.py +++ b/python/sdist/amici/importers/pysb/__init__.py @@ -147,7 +147,7 @@ def pysb2jax( pysb_model_has_obs_and_noise=pysb_model_has_obs_and_noise, ) - from amici.jax.ode_export import ODEExporter + from amici.exporters.jax.ode_export import ODEExporter exporter = ODEExporter( ode_model, diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 5584ba5657..d444a3999e 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -526,7 +526,7 @@ def sbml2jax( hybridization=hybridization, ) - from amici.jax.ode_export import ODEExporter + from amici.exporters.jax.ode_export import ODEExporter exporter = ODEExporter( ode_model, diff --git a/python/sdist/amici/jax/__init__.py b/python/sdist/amici/jax/__init__.py deleted file mode 100644 index 5c2e24fb31..0000000000 --- a/python/sdist/amici/jax/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -JAX ---- - -This module provides an interface to generate and use AMICI models with JAX. Please note that this module is -experimental, the API may substantially change in the future. Use at your own risk and do not expect backward -compatibility. -""" - -from warnings import warn - -from amici.jax.model import JAXModel -from amici.jax.nn import Flatten, cat, generate_equinox, tanhshrink -from amici.jax.petab import ( - JAXProblem, - ReturnValue, - petab_simulate, - run_simulations, -) - -warn( - "The JAX module is experimental and the API may change in the future.", - ImportWarning, - stacklevel=2, -) - -__all__ = [ - "JAXModel", - "JAXProblem", - "Flatten", - "generate_equinox", - "run_simulations", - "petab_simulate", - "ReturnValue", - "tanhshrink", - "cat", -] diff --git a/python/sdist/amici/sim/jax/__init__.py b/python/sdist/amici/sim/jax/__init__.py index 29b0399c9f..efdc367456 100644 --- a/python/sdist/amici/sim/jax/__init__.py +++ b/python/sdist/amici/sim/jax/__init__.py @@ -1,138 +1,23 @@ -"""Functionality for simulating JAX-based AMICI models.""" - -import jax.numpy as jnp -import pandas as pd -import petab.v2 as petabv2 - - -def add_default_experiment_names_to_v2_problem(petab_problem: petabv2.Problem): - """Add default experiment names to PEtab v2 problem. - - Args: - petab_problem: PEtab v2 problem to modify. - """ - if not hasattr(petab_problem, "extensions_config"): - petab_problem.extensions_config = {} - - petab_problem.visualization_df = None - - if petab_problem.condition_df is None: - default_condition = petabv2.core.Condition( - id="__default__", changes=[], conditionId="__default__" - ) - petab_problem.condition_tables[0].elements = [default_condition] - - if ( - petab_problem.experiment_df is None - or petab_problem.experiment_df.empty - ): - condition_ids = petab_problem.condition_df[ - petabv2.C.CONDITION_ID - ].values - condition_ids = [ - c for c in condition_ids if "preequilibration" not in c - ] - default_experiment = petabv2.core.Experiment( - id="__default__", - periods=[ - petabv2.core.ExperimentPeriod( - time=0.0, condition_ids=condition_ids - ) - ], - ) - petab_problem.experiment_tables[0].elements = [default_experiment] - - measurement_tables = petab_problem.measurement_tables.copy() - for mt in measurement_tables: - for m in mt.elements: - m.experiment_id = "__default__" - - petab_problem.measurement_tables = measurement_tables - - return petab_problem - - -def get_simulation_conditions_v2(petab_problem) -> pd.DataFrame: - """Get simulation conditions from PEtab v2 measurement DataFrame. - - Returns: - A pandas DataFrame mapping experiment_ids to condition ids. - """ - experiment_df = petab_problem.experiment_df - exps = {} - for exp_id in experiment_df[petabv2.C.EXPERIMENT_ID].unique(): - exps[exp_id] = experiment_df[ - experiment_df[petabv2.C.EXPERIMENT_ID] == exp_id - ][petabv2.C.CONDITION_ID].unique() - - experiment_df = experiment_df.drop(columns=[petabv2.C.TIME]) - return experiment_df - - -def _build_simulation_df_v2(problem, y, dyn_conditions): - """Build petab simulation DataFrame of similation results from a PEtab v2 problem.""" - dfs = [] - for ic, sc in enumerate(dyn_conditions): - experiment_id = _conditions_to_experiment_map( - problem._petab_problem.experiment_df - )[sc] - - if experiment_id == "__default__": - experiment_id = jnp.nan - - obs = [ - problem.model.observable_ids[io] - for io in problem._iys[ic, problem._ts_masks[ic, :]] - ] - t = jnp.concat( - ( - problem._ts_dyn[ic, :], - problem._ts_posteq[ic, :], - ) - ) - df_sc = pd.DataFrame( - { - petabv2.C.MODEL_ID: [float("nan")] * len(t), - petabv2.C.OBSERVABLE_ID: obs, - petabv2.C.EXPERIMENT_ID: [experiment_id] * len(t), - petabv2.C.TIME: t[problem._ts_masks[ic, :]], - petabv2.C.SIMULATION: y[ic, problem._ts_masks[ic, :]], - }, - index=problem._petab_measurement_indices[ic, :], - ) - if ( - petabv2.C.OBSERVABLE_PARAMETERS - in problem._petab_problem.measurement_df - ): - df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = ( - problem._petab_problem.measurement_df.query( - f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" - )[petabv2.C.OBSERVABLE_PARAMETERS] - ) - if petabv2.C.NOISE_PARAMETERS in problem._petab_problem.measurement_df: - df_sc[petabv2.C.NOISE_PARAMETERS] = ( - problem._petab_problem.measurement_df.query( - f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" - )[petabv2.C.NOISE_PARAMETERS] - ) - dfs.append(df_sc) - return pd.concat(dfs).sort_index() - - -def _conditions_to_experiment_map( - experiment_df: pd.DataFrame, -) -> dict[str, str]: - condition_to_experiment = { - row.conditionId: row.experimentId for row in experiment_df.itertuples() - } - return condition_to_experiment - - -def _try_float(value): - try: - return float(value) - except Exception as e: - msg = str(e).lower() - if isinstance(e, ValueError) and "could not convert" in msg: - return value - raise +""" +Functionality for simulating JAX-based AMICI models. + +This module provides an interface to generate and use AMICI models with JAX. +Please note that this module is experimental, the API may substantially change +in the future. Use at your own risk and do not expect backward compatibility. +""" + +from .model import JAXModel +from .petab import ( + JAXProblem, + ReturnValue, + petab_simulate, + run_simulations, +) + +__all__ = [ + "JAXModel", + "JAXProblem", + "run_simulations", + "petab_simulate", + "ReturnValue", +] diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/sim/jax/_simulation.py similarity index 100% rename from python/sdist/amici/jax/_simulation.py rename to python/sdist/amici/sim/jax/_simulation.py diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/sim/jax/model.py similarity index 99% rename from python/sdist/amici/jax/model.py rename to python/sdist/amici/sim/jax/model.py index b2274e1fb1..bc75dd388b 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/sim/jax/model.py @@ -16,7 +16,7 @@ import jaxtyping as jt from optimistix import AbstractRootFinder -from ._simulation import _apply_event_assignments, eq, solve +from amici.sim.jax._simulation import _apply_event_assignments, eq, solve class ReturnValue(enum.Enum): diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/sim/jax/petab.py similarity index 93% rename from python/sdist/amici/jax/petab.py rename to python/sdist/amici/sim/jax/petab.py index ae2172c81d..7a214af567 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/sim/jax/petab.py @@ -25,14 +25,8 @@ from amici.importers.petab.v1.parameter_mapping import ( ParameterMappingForCondition, ) -from amici.jax.model import JAXModel, ReturnValue from amici.logging import get_logger -from amici.sim.jax import ( - _build_simulation_df_v2, - _try_float, - add_default_experiment_names_to_v2_problem, - get_simulation_conditions_v2, -) +from amici.sim.jax.model import JAXModel, ReturnValue DEFAULT_CONTROLLER_SETTINGS = { "atol": 1e-8, @@ -1955,3 +1949,136 @@ def petab_simulate( ) dfs.append(df_sc) return pd.concat(dfs).sort_index() + + +def add_default_experiment_names_to_v2_problem(petab_problem: petabv2.Problem): + """Add default experiment names to PEtab v2 problem. + + Args: + petab_problem: PEtab v2 problem to modify. + """ + if not hasattr(petab_problem, "extensions_config"): + petab_problem.extensions_config = {} + + petab_problem.visualization_df = None + + if petab_problem.condition_df is None: + default_condition = petabv2.core.Condition( + id="__default__", changes=[], conditionId="__default__" + ) + petab_problem.condition_tables[0].elements = [default_condition] + + if ( + petab_problem.experiment_df is None + or petab_problem.experiment_df.empty + ): + condition_ids = petab_problem.condition_df[ + petabv2.C.CONDITION_ID + ].values + condition_ids = [ + c for c in condition_ids if "preequilibration" not in c + ] + default_experiment = petabv2.core.Experiment( + id="__default__", + periods=[ + petabv2.core.ExperimentPeriod( + time=0.0, condition_ids=condition_ids + ) + ], + ) + petab_problem.experiment_tables[0].elements = [default_experiment] + + measurement_tables = petab_problem.measurement_tables.copy() + for mt in measurement_tables: + for m in mt.elements: + m.experiment_id = "__default__" + + petab_problem.measurement_tables = measurement_tables + + return petab_problem + + +def get_simulation_conditions_v2(petab_problem) -> pd.DataFrame: + """Get simulation conditions from PEtab v2 measurement DataFrame. + + Returns: + A pandas DataFrame mapping experiment_ids to condition ids. + """ + experiment_df = petab_problem.experiment_df + exps = {} + for exp_id in experiment_df[petabv2.C.EXPERIMENT_ID].unique(): + exps[exp_id] = experiment_df[ + experiment_df[petabv2.C.EXPERIMENT_ID] == exp_id + ][petabv2.C.CONDITION_ID].unique() + + experiment_df = experiment_df.drop(columns=[petabv2.C.TIME]) + return experiment_df + + +def _build_simulation_df_v2(problem, y, dyn_conditions): + """Build petab simulation DataFrame of similation results from a PEtab v2 problem.""" + dfs = [] + for ic, sc in enumerate(dyn_conditions): + experiment_id = _conditions_to_experiment_map( + problem._petab_problem.experiment_df + )[sc] + + if experiment_id == "__default__": + experiment_id = jnp.nan + + obs = [ + problem.model.observable_ids[io] + for io in problem._iys[ic, problem._ts_masks[ic, :]] + ] + t = jnp.concat( + ( + problem._ts_dyn[ic, :], + problem._ts_posteq[ic, :], + ) + ) + df_sc = pd.DataFrame( + { + petabv2.C.MODEL_ID: [float("nan")] * len(t), + petabv2.C.OBSERVABLE_ID: obs, + petabv2.C.EXPERIMENT_ID: [experiment_id] * len(t), + petabv2.C.TIME: t[problem._ts_masks[ic, :]], + petabv2.C.SIMULATION: y[ic, problem._ts_masks[ic, :]], + }, + index=problem._petab_measurement_indices[ic, :], + ) + if ( + petabv2.C.OBSERVABLE_PARAMETERS + in problem._petab_problem.measurement_df + ): + df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" + )[petabv2.C.OBSERVABLE_PARAMETERS] + ) + if petabv2.C.NOISE_PARAMETERS in problem._petab_problem.measurement_df: + df_sc[petabv2.C.NOISE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" + )[petabv2.C.NOISE_PARAMETERS] + ) + dfs.append(df_sc) + return pd.concat(dfs).sort_index() + + +def _conditions_to_experiment_map( + experiment_df: pd.DataFrame, +) -> dict[str, str]: + condition_to_experiment = { + row.conditionId: row.experimentId for row in experiment_df.itertuples() + } + return condition_to_experiment + + +def _try_float(value): + try: + return float(value) + except Exception as e: + msg = str(e).lower() + if isinstance(e, ValueError) and "could not convert" in msg: + return value + raise diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index d6cdb45c1b..d501894932 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -4,7 +4,6 @@ import pytest pytest.importorskip("jax") -import amici.jax import diffrax import jax import jax.numpy as jnp @@ -15,7 +14,7 @@ from amici import import_model_module from amici.importers.petab.v1 import import_petab_problem from amici.importers.pysb import pysb2amici, pysb2jax -from amici.jax import JAXProblem, ReturnValue, run_simulations +from amici.sim.jax import JAXProblem, ReturnValue, run_simulations from amici.sim.sundials import ( ExpData, SensitivityMethod, @@ -342,8 +341,8 @@ def test_time_dependent_discontinuity(tmp_path): from amici.importers.antimony import antimony2sbml from amici.importers.sbml import SbmlImporter - from amici.jax._simulation import solve - from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS + from amici.sim.jax._simulation import solve + from amici.sim.jax.petab import DEFAULT_CONTROLLER_SETTINGS ant_model = """ model time_disc @@ -408,8 +407,8 @@ def test_time_dependent_discontinuity_equilibration(tmp_path): from amici.importers.antimony import antimony2sbml from amici.importers.sbml import SbmlImporter - from amici.jax._simulation import eq - from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS + from amici.sim.jax._simulation import eq + from amici.sim.jax.petab import DEFAULT_CONTROLLER_SETTINGS ant_model = """ model time_disc_eq diff --git a/python/tests/test_sciml.py b/python/tests/test_sciml.py index 84af564bfe..0403e1586e 100644 --- a/python/tests/test_sciml.py +++ b/python/tests/test_sciml.py @@ -8,7 +8,7 @@ pytest.importorskip("equinox") import pytest -from amici.jax.nn import ( +from amici.exporters.jax.nn import ( _format_function_call, _generate_forward, _process_activation_call, @@ -229,7 +229,7 @@ def test_mapped_activation_tanhshrink(self): node.kwargs = {} fun_str = _process_activation_call(node) - assert fun_str == "amici.jax.tanhshrink" + assert fun_str == "amici.export.jax.tanhshrink" def test_hardtanh_valid_params(self): """Test hardtanh with valid default parameters.""" diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index 084ec787bb..2751179738 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -10,7 +10,7 @@ from amici.importers.petab.v1 import ( import_petab_problem, ) -from amici.jax.petab import run_simulations +from amici.sim.jax.petab import run_simulations from amici.sim.sundials import SensitivityMethod, SensitivityOrder from amici.sim.sundials.petab.v1 import ( LLH, diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index 581201713e..8c28a55f8e 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -84,7 +84,7 @@ def _test_case(case, model_type, version, jax): jax=jax, ) if jax: - from amici.jax import petab_simulate, run_simulations + from amici.sim.jax import petab_simulate, run_simulations steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6) jax_problem = ( diff --git a/tests/petab_test_suite/test_petab_v2_suite.py b/tests/petab_test_suite/test_petab_v2_suite.py index 6895da6149..78038ad38a 100755 --- a/tests/petab_test_suite/test_petab_v2_suite.py +++ b/tests/petab_test_suite/test_petab_v2_suite.py @@ -71,8 +71,8 @@ def _test_case(case, model_type, version, jax): ) if jax: - from amici.jax import petab_simulate, run_simulations - from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS + from amici.sim.jax import petab_simulate, run_simulations + from amici.sim.jax.petab import DEFAULT_CONTROLLER_SETTINGS steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6) diff --git a/tests/sbml/testSBMLSuite.py b/tests/sbml/testSBMLSuite.py index 6b596bcffa..b89d266db5 100755 --- a/tests/sbml/testSBMLSuite.py +++ b/tests/sbml/testSBMLSuite.py @@ -23,7 +23,7 @@ import optimistix import pandas as pd import pytest -from amici.jax.petab import ( +from amici.sim.jax.petab import ( DEFAULT_CONTROLLER_SETTINGS, DEFAULT_ROOT_FINDER_SETTINGS, ) @@ -240,7 +240,7 @@ def simulate(pars): diffrax.DirectAdjoint(), diffrax.SteadyStateEvent(), 2**10, - ret=amici.jax.ReturnValue.x, + ret=amici.sim.jax.ReturnValue.x, ) return x diff --git a/tests/sbml/testSBMLSuiteJax.py b/tests/sbml/testSBMLSuiteJax.py index 772c881e49..4bf9f34e2b 100644 --- a/tests/sbml/testSBMLSuiteJax.py +++ b/tests/sbml/testSBMLSuiteJax.py @@ -12,7 +12,7 @@ import optimistix import pandas as pd import pytest -from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS +from amici.sim.jax.petab import DEFAULT_CONTROLLER_SETTINGS from amici.sim.sundials import AMICI_SUCCESS from utils import ( find_model_file, @@ -87,7 +87,7 @@ def run_jax_simulation(model, importer, ts, atol, rtol, tol_factor=1e2): diffrax.DirectAdjoint(), diffrax.SteadyStateEvent(), 2**10, - ret=amici.jax.ReturnValue.x, + ret=amici.sim.jax.ReturnValue.x, ) y = jax.vmap( lambda t, x_solver, x_rdata, hs: model._y( diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index cdeccd630f..9bda8c24f6 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -13,12 +13,9 @@ import pandas as pd import petab.v1 as petab import pytest +from amici.exporters.jax import generate_equinox from amici.importers.petab.v1 import import_petab_problem -from amici.jax import ( - generate_equinox, - petab_simulate, - run_simulations, -) +from amici.sim.jax import petab_simulate, run_simulations from petab_sciml import NNModelStandard from yaml import safe_load