From b25ca24737ad7b1e4e6f042a1a84c2f03627f6a6 Mon Sep 17 00:00:00 2001 From: ntfrgl Date: Thu, 30 Apr 2026 10:59:54 -0700 Subject: [PATCH 1/4] Introduce `ecoli.library.xarray_emitter` --- .gitignore | 2 +- configs/test_configs/test_xarray_emitter.json | 117 +++ doc/apidoc_templates/module.rst_t | 2 - doc/composites.rst | 9 +- doc/conf.py | 44 +- doc/experiments.rst | 3 + doc/stores.rst | 2 + doc/workflows.rst | 5 + ecoli/__init__.py | 2 + ecoli/experiments/ecoli_master_sim.py | 196 +++-- ecoli/library/emitter.py | 168 ++++ ecoli/library/parquet_emitter.py | 102 ++- ecoli/library/test_parquet_emitter.py | 284 +++---- ecoli/library/test_utils.py | 99 +++ ecoli/library/xarray_emitter/__init__.py | 288 +++++++ ecoli/library/xarray_emitter/emit_path.py | 136 +++ .../library/xarray_emitter/emit_predicate.py | 225 +++++ ecoli/library/xarray_emitter/emitter.py | 233 +++++ ecoli/library/xarray_emitter/storage.py | 369 ++++++++ .../xarray_emitter/test_xarray_emitter.py | 300 +++++++ ecoli/library/xarray_emitter/transducer.py | 601 +++++++++++++ ecoli/library/xarray_emitter/utils.py | 65 ++ ecoli/library/xarray_emitter/view.py | 386 +++++++++ ecoli/library/xarray_emitter/writer.py | 523 ++++++++++++ ecoli/library/xarray_emitter/zarr_writer.py | 801 ++++++++++++++++++ ecoli/processes/engine_process.py | 27 +- ecoli/processes/listeners/mass_listener.py | 4 +- pyproject.toml | 8 + pytest.ini | 6 + runscripts/test_workflow.py | 304 ++++++- uv.lock | 80 +- 31 files changed, 5086 insertions(+), 305 deletions(-) create mode 100644 configs/test_configs/test_xarray_emitter.json create mode 100644 ecoli/library/emitter.py create mode 100644 ecoli/library/test_utils.py create mode 100644 ecoli/library/xarray_emitter/__init__.py create mode 100644 ecoli/library/xarray_emitter/emit_path.py create mode 100644 ecoli/library/xarray_emitter/emit_predicate.py create mode 100644 ecoli/library/xarray_emitter/emitter.py create mode 100644 ecoli/library/xarray_emitter/storage.py create mode 100644 
ecoli/library/xarray_emitter/test_xarray_emitter.py create mode 100644 ecoli/library/xarray_emitter/transducer.py create mode 100644 ecoli/library/xarray_emitter/utils.py create mode 100644 ecoli/library/xarray_emitter/view.py create mode 100644 ecoli/library/xarray_emitter/writer.py create mode 100644 ecoli/library/xarray_emitter/zarr_writer.py diff --git a/.gitignore b/.gitignore index cf55a3998..514d3deff 100644 --- a/.gitignore +++ b/.gitignore @@ -69,4 +69,4 @@ trace-* test_sherlock/ # SMS API # -.hpc_env \ No newline at end of file +.hpc_env diff --git a/configs/test_configs/test_xarray_emitter.json b/configs/test_configs/test_xarray_emitter.json new file mode 100644 index 000000000..0c4532c71 --- /dev/null +++ b/configs/test_configs/test_xarray_emitter.json @@ -0,0 +1,117 @@ +{ + "experiment_id": "test_xarray_emitter", + "fixed_media": "minimal", + "suffix_time": false, + "max_duration": 10.0, + "fail_at_max_duration": false, + "generations": 1, + "n_init_sims": 1, + "log_updates": true, + "emitter": "xarray", + "emitter_arg": { + "debug": false, + "transducer": { + "predicate": [ + [ + {"subsample": {"interval": 1}}, + {"fixed": {"steps": [0]}} + ] + ], + "buffer": { + "size": 3 + } + }, + "writer": { + "store": "out/store", + "threaded": true, + "buffers_per_chunk": 2, + "backend": "zarr", + "backend_config": { + "format": 3, + "async.concurrency": 3, + "threading.max_workers": 3 + } + }, + "view": [ + { + "root": [], + "variables": { + "bulk": [{ + "path": "bulk/bulk_molecule", + "dtype": "` or the + :py:mod:`.xarray_emitter`. 
+ ------------- Initial State diff --git a/doc/conf.py b/doc/conf.py index 9abcc5c31..06ffe1a0e 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -86,6 +86,20 @@ # Silence warning in ecoli.processes.environment.field_timeline.FieldTimeline ("py:class", "vivarium.processes.timeline.TimelineProcess"), ("py:class", "concurrent.futures._base.Future"), + # Type annotations using library internals + ("py:class", "concurrent.futures._base.Executor"), + ("py:class", "unittest.mock._patch"), + ("py:class", "xarray.backends.common.ArrayWriter"), + ("py:class", "xarray.core.treenode.NodePath"), + ("py:class", "zarr.core.group.ConsolidatedMetadata"), + ("py:class", "zarr.core._tree.TreeRepr"), + # Sphinx does not recognize type parameters in generic classes + ("py:class", "ArrT"), + ("py:class", "NodeT"), + ("py:class", "StoreT"), + # Sphinx does not recognize type aliases + ("py:type", "VariableEncoding"), + ("py:class", "VariableEncoding"), ] @@ -115,15 +129,17 @@ # -- sphinx.ext.intersphinx options -- intersphinx_mapping = { "python": ("https://docs.python.org/3", None), - "vivarium": ( - "https://vivarium-core.readthedocs.io/en/latest/", - None, - ), + "pytest": ("https://docs.pytest.org/en/latest", None), + "vivarium": ("https://vivarium-core.readthedocs.io/en/latest/", None), "numpy": ("https://numpy.org/doc/stable", None), + "xarray": ("https://docs.xarray.dev/en/latest", None), + "zarr": ("https://zarr.readthedocs.io/en/latest", None), "matplotlib": ("https://matplotlib.org/stable/", None), "pandas": ("http://pandas.pydata.org/pandas-docs/dev", None), "polars": ("https://docs.pola.rs/api/python/stable", None), "sympy": ("https://docs.sympy.org/latest", None), + "pint": ("https://pint.readthedocs.io/en/stable", None), + "unum": ("https://unum.readthedocs.io/en/stable", None), } @@ -143,8 +159,24 @@ ] # Move typehints from signature into description autodoc_typehints = "description" -# Concatenate class and __init__ docstrings -autoclass_content = "both" +# Only use the 
class’s docstring. __init__ docstrings are now listed separately. +autoclass_content = "class" +# Default options for all autodoc directives. +autodoc_default_options = { + "member-order": "bysource", + "private-members": True, + "special-members": ( + # object + "__init__, __del__, __call__" + ), + "exclude-members": ( + # abc.ABC + "_abc_impl, " + # enum.Flag + "_flag_mask_, _singles_mask_, _all_bits_, _boundary_, _inverted_, " + "_generate_next_value_" + ) +} # Remove domain objects (e.g. functions, classes, attributes) from # table of contents toc_object_entries = False diff --git a/doc/experiments.rst b/doc/experiments.rst index 9e6a423e2..2fe680faa 100644 --- a/doc/experiments.rst +++ b/doc/experiments.rst @@ -396,6 +396,9 @@ Here are some general rules to remember when writing your own JSON config files: without dividing, this results in a more informative error message instead of a Nextflow error about missing daughter cell states. + +.. _experiment_output: + ------ Output ------ diff --git a/doc/stores.rst b/doc/stores.rst index d1a379355..4a28d433d 100644 --- a/doc/stores.rst +++ b/doc/stores.rst @@ -368,6 +368,8 @@ any number of attributes for all active (``_entryState`` is 1) unique molecules of a given type (e.g. RNA, active RNAP, etc.). +.. _listeners: + --------- Listeners --------- diff --git a/doc/workflows.rst b/doc/workflows.rst index d015d0dc0..4a1848e5c 100644 --- a/doc/workflows.rst +++ b/doc/workflows.rst @@ -330,6 +330,7 @@ it has access to. folder, you can just create stub files in the appropriate folders that simply import the ``plot`` function from a primary analysis script. + .. _analysis_config: Configuration @@ -502,6 +503,9 @@ Refer to :ref:`/output.rst` for more information about how to use DuckDB to read and analyze simulation output inside analysis scripts. + +.. _workflows: + --------- Workflows --------- @@ -700,6 +704,7 @@ is a list workflow behaviors enabled in our model to handle unexpected errors. 
depends on generation 6, :py:mod:`runscripts.create_variants` depends on :py:mod:`runscripts.parca`, etc). + .. _output: ------ diff --git a/ecoli/__init__.py b/ecoli/__init__.py index 673b8391f..24e2c5c2f 100644 --- a/ecoli/__init__.py +++ b/ecoli/__init__.py @@ -6,6 +6,7 @@ ) from ecoli.library.parquet_emitter import ParquetEmitter +from ecoli.library.xarray_emitter.emitter import XarrayEmitter from ecoli.library.schema import ( divide_binomial, divide_bulk, @@ -39,6 +40,7 @@ faulthandler.enable() emitter_registry.register("parquet", ParquetEmitter) +emitter_registry.register("xarray", XarrayEmitter) # register :term:`updaters` inverse_updater_registry.register("accumulate", inverse_update_accumulate) diff --git a/ecoli/experiments/ecoli_master_sim.py b/ecoli/experiments/ecoli_master_sim.py index b77bc2be9..efa2aca2e 100644 --- a/ecoli/experiments/ecoli_master_sim.py +++ b/ecoli/experiments/ecoli_master_sim.py @@ -24,15 +24,16 @@ import numpy as np from fsspec import open as fsspec_open from vivarium.core.engine import Engine -from vivarium.core.composer import deep_merge +from vivarium.core.composer import Composite, deep_merge from vivarium.core.process import Process from vivarium.core.serialize import deserialize_value, serialize_value from vivarium.library.dict_utils import deep_merge_check from vivarium.library.topology import inverse_topology from vivarium.library.topology import assoc_path, get_in from ecoli.library.logging_tools import write_json -from wholecell.utils.filepath import cloud_path_join +from ecoli.library.parquet_emitter import BufferedEmitter import ecoli.composites.ecoli_master +from wholecell.utils.filepath import cloud_path_join # Environment composer for spatial environment sim import ecoli.composites.environment.lattice @@ -42,7 +43,6 @@ from ecoli.processes.registries import topology_registry from configs import CONFIG_DIR_PATH -from ecoli.library.parquet_emitter import ParquetEmitter from ecoli.library.schema import not_a_process 
from wholecell.utils.filepath import ROOT_PATH @@ -256,6 +256,7 @@ def __init__( self.parser.add_argument( "--experiment_id", action="store", + type=str, help=( "ID for this experiment. A UUID will be generated if " 'this argument is not used and "experiment_id" is null ' @@ -346,8 +347,7 @@ def __init__( "--variant", action="store", help="Name of variant." ) self.parser.add_argument( - "--lineage_seed", - action="store", + "--lineage_seed", action="store", type=int, help="Seed used for first cell in lineage.", ) self.parser.add_argument( @@ -467,15 +467,19 @@ def __init__(self, config: dict[str, Any]): # Keep track of base experiment id # in case multiple simulations are run with suffix_time = True. - self.experiment_id_base = config["experiment_id"] + self.experiment_id: str + self.experiment_id_base: str = config["experiment_id"] self.config = config - self.ecoli = None - """vivarium.core.composer.Composite: Contains the fully instantiated - processes, steps, topologies, and flow necessary to run simulation. - Generated by + self.emitter_config: dict[str, Any] = {} + + self.ecoli: Composite | None = None + """ + Contains the fully instantiated processes, steps, topologies, and flow + necessary to run simulation. Generated by :py:meth:`~ecoli.experiments.ecoli_master_sim.EcoliSim.build_ecoli` and cleared when :py:meth:`~ecoli.experiments.ecoli_master_sim.EcoliSim.run` - is called to potentially free up memory after division.""" + is called to potentially free up memory after division. + """ self.generated_initial_state = None """dict: Fully populated initial state for simulation. Generated by :py:meth:`~ecoli.experiments.ecoli_master_sim.EcoliSim.build_ecoli` and @@ -744,60 +748,77 @@ def build_ecoli(self): initial_environment, self.generated_initial_state ) - def update_experiment(self, time_to_update: float = 0.0): + def update_experiment( + self, time_to_update: float = 0.0, finalize: bool = True + ) -> None: """ - Runs the E. coli simulation for a specified amount of time. 
If the - simulation reaches a division event and ``config['generations']`` is set, - it will save the daughter cell states to JSON files in the directory - specified by ``config['daughter_outdir']``. Also creates a file - ``division_time.sh`` that, when executed, sets the environment variable - ``division_time`` to the time at which division occurred (used in - Nextflow workflow runs). + Run the E. coli simulation for a specified amount of time. If the + simulation reaches a division event during this time and + ``config['generations']`` is set, then :py:meth:`~.persist_generation` + will be called and the Python interpreter will be terminated. + + Called by: :py:meth:`.run` or :py:meth:`.save_states`. """ try: - self.ecoli_experiment.update(time_to_update) + success = False + if time_to_update > 0: + self.ecoli_experiment.update(time_to_update) except DivisionDetected: - state = self.ecoli_experiment.state.get_value(condition=not_a_process) - assert len(state["agents"]) == 2 - # Daughter state should include all of the additional - # non-agent state (e.g. environment state) - non_agent_state = {k: v for k, v in state.items() if k != "agents"} - for i, (agent_id, agent_state) in enumerate(state["agents"].items()): - prepare_save_state(agent_state) - daughter_filename = f"daughter_state_{i}.json" - daughter_path = cloud_path_join(self.daughter_outdir, daughter_filename) - write_json( - daughter_path, - {**non_agent_state, "agents": {agent_id: agent_state}}, - ) - # Write daughter state URI to local file for Nextflow to read - with open(f"daughter_state_{i}_uri.txt", "w") as f: - f.write(daughter_path) - print( - f"Divided at t = {self.ecoli_experiment.global_time} after " - f"{self.ecoli_experiment.global_time - self.initial_global_time} sec." + self.persist_generation() + success = True + finally: + # Don't start new I/O operations during a manual shutdown. 
+ if not isinstance(sys.exception(), KeyboardInterrupt): + if isinstance(emitter := self.ecoli_experiment.emitter, BufferedEmitter): + # Finish writing buffered emits to persistent storage. Inside the + # `.save_states()` loop, only do so on division or on an error. + if finalize or success or sys.exception() is not None: + emitter.finalize(success=success) + if success: + # Exit so that `.run()` does not raise `TimeLimitError`. + sys.exit() + + def persist_generation(self, *, num_agents: int = 2) -> None: + """ + Upon reaching cell division, save the daughter cell states to JSON files + in the directory specified by ``config['daughter_outdir']``. Also, + create a file ``division_time.sh`` that, when executed, sets the + environment variable ``division_time`` to the time at which division + occurred, as expected by + ``runscripts/nextflow/sim.nf::{simGen0,sim}.output``. + + Called by: :py:meth:`~.update_experiment`. + + Args: + num_agents: Expected number of cells. This argument exists solely + for testing purposes. + """ + state = self.ecoli_experiment.state.get_value(condition=not_a_process) + assert len(state["agents"]) == num_agents + # Daughter state should include all of the additional + # non-agent state (e.g. 
environment state) + non_agent_state = {k: v for k, v in state.items() if k != "agents"} + for i, (agent_id, agent_state) in enumerate(state["agents"].items()): + prepare_save_state(agent_state) + daughter_filename = f"daughter_state_{i}.json" + daughter_path = cloud_path_join(self.daughter_outdir, daughter_filename) + write_json( + daughter_path, + {**non_agent_state, "agents": {agent_id: agent_state}}, ) - # Nextflow workflows will source division time to determine - # initial global time to use for daughter cells - with open("division_time.sh", "w") as f: - f.write(f"export division_time={self.ecoli_experiment.global_time}") - # Tell Parquet emitter that simulation was successful - if isinstance(self.ecoli_experiment.emitter, ParquetEmitter): - self.ecoli_experiment.emitter.success = True - self.ecoli_experiment.emitter.finalize() - # Exit so that EcoliSim.run() does not raise TimeLimitError - sys.exit() - except: # noqa: E722 - # Finish writing any buffered emits to Parquet files if the simulation - # encounters any error (including KeyboardInterrupt) - # We use a bare except instead of finally because we don't want to - # run finalize() every time update_experiment is called to advance to - # save times in save_states() - if isinstance(self.ecoli_experiment.emitter, ParquetEmitter): - self.ecoli_experiment.emitter.finalize() - raise - - def save_states(self): + # Write daughter state URI to local file for Nextflow to read + with open(f"daughter_state_{i}_uri.txt", "w") as f: + f.write(daughter_path) + print( + f"Divided at t = {self.ecoli_experiment.global_time} after " + f"{self.ecoli_experiment.global_time - self.initial_global_time} sec." 
+ ) + # Nextflow workflows will source division time to determine + # initial global time to use for daughter cells + with open("division_time.sh", "w") as f: + f.write(f"export division_time={self.ecoli_experiment.global_time}") + + def save_states(self) -> None: """ Runs the simulation while saving the states of specific timesteps to files named ``data/vivecoli_t{time}.json``. Invoked by @@ -805,6 +826,10 @@ def save_states(self): if ``config['save'] == True``. State is saved as a JSON that can be reloaded into a simulation as described in :py:meth:`~ecoli.composites.ecoli_master.Ecoli.initial_state`. + + Called by: :py:meth:`.run`. + + Calls: :py:meth:`~.update_experiment`. """ for time in self.save_times: if time > self.max_duration: @@ -818,7 +843,7 @@ def save_states(self): time_to_next_save = self.save_times[i] else: time_to_next_save = self.save_times[i] - self.save_times[i - 1] - self.update_experiment(time_to_next_save) + self.update_experiment(time_to_next_save, finalize=False) time_elapsed = self.save_times[i] state = self.ecoli_experiment.state.get_value(condition=not_a_process) if self.divide: @@ -829,18 +854,19 @@ def save_states(self): write_json("data/vivecoli_t" + str(time_elapsed) + ".json", state) print("Finished saving the state at t = " + str(time_elapsed)) time_remaining = self.max_duration - self.save_times[-1] - if time_remaining: - self.update_experiment(time_remaining) + self.update_experiment(time_remaining) - def run(self): - """Create and run an EcoliSim experiment. If the simulation reaches + def run(self) -> None: + """ + Create and run an EcoliSim experiment. If the simulation reaches the maximum duration specified by ``config['max_duration']``, it will raise a :py:class:`~ecoli.experiments.ecoli_master_sim.TimeLimitError` if ``config['fail_at_max_duration']`` is ``True``. + Calls: :py:meth:`~.update_experiment` or :py:meth:`~.save_states`. + .. 
WARNING:: - Run :py:meth:`~ecoli.experiments.ecoli_master_sim.EcoliSim.build_ecoli` - before calling :py:meth:`~ecoli.experiments.ecoli_master_sim.EcoliSim.run`! + Run :py:meth:`~.build_ecoli` before calling :py:meth:`~.run`! """ if self.ecoli is None: raise RuntimeError( @@ -857,15 +883,22 @@ def run(self): for key, value in self.emitter_arg.items(): self.emitter_config[key] = value if self.emitter == "parquet": - if ("out_dir" not in self.emitter_config) and ( - "out_uri" not in self.emitter_config - ): - raise RuntimeError( + if not any(map(self.emitter_config.__contains__, + ["out_dir", "out_uri"])): + raise KeyError( "Must provide out_dir or out_uri" - " as emitter argument for parquet emitter." - ) + " as emitter argument for parquet emitter.") + elif self.emitter == "xarray": + if not ( + not any(map(self.emitter_config.__contains__, + ["out_dir", "out_uri"])) + and "store" in self.emitter_config.get("writer", {}) + ): + raise KeyError( + "For {\"emitter\": \"xarray\"}, please provide:\n" + " {\"emitter_arg\": {\"writer\": {\"store\": ... }}}") else: - raise RuntimeError( + raise TypeError( "Emitter option must be a string" " representing the emitter type with any additional config" " options under the emitter_arg key." 
@@ -903,6 +936,10 @@ def run(self): f" != {parse.quote_plus(self.experiment_id)}" ) experiment_config["experiment_id"] = self.experiment_id + # Ensure that `suffix_time` is in effect for all duplicates + # of `experiment_id` + assert metadata["experiment_id"] == self.experiment_id_base + metadata["experiment_id"] = self.experiment_id experiment_config["profile"] = self.profile # Since unique numpy updater is an class method, internal @@ -917,12 +954,11 @@ def run(self): self.ecoli_experiment = Engine(**experiment_config) # Only emit designated stores if specified - if self.config["emit_paths"]: - self.ecoli_experiment.state.set_emit_values([tuple()], False) - self.ecoli_experiment.state.set_emit_values( - self.config["emit_paths"], - True, - ) + if isinstance(emitter := self.ecoli_experiment.emitter, BufferedEmitter): + emitter.reset_emit_flags( + engine=self.ecoli_experiment, + agent=("agents", self.agent_id), + emit_paths=self.config["emit_paths"]) # Clean up unnecessary references self.generated_initial_state = None diff --git a/ecoli/library/emitter.py b/ecoli/library/emitter.py new file mode 100644 index 000000000..095a68b22 --- /dev/null +++ b/ecoli/library/emitter.py @@ -0,0 +1,168 @@ + +""" +Extensions to the :py:class:`~vivarium.core.emitter.Emitter` interface, as used +by :py:class:`.ParquetEmitter` and :py:class:`.XarrayEmitter`. 
+""" + + +from __future__ import annotations + +from abc import ABC, abstractmethod +from concurrent.futures import Future, Executor +from dataclasses import dataclass, field, replace +from typing import Any, Callable, Self +from urllib import parse +from warnings import warn + +from vivarium.core.types import HierarchyPath +from vivarium.core.engine import Engine +from vivarium.core.emitter import Emitter + + +# ============================================================================== + + +class BlockingExecutor(Executor): + + def __init__(self, *args) -> None: + assert not len(args) + super().__init__() + + def submit(self, fn: Callable, /, *args, **kwargs) -> Future: + """ + Run a function in the current thread, and return a + :py:class:`~concurrent.futures.Future` that is already done. + """ + future: Future = Future() + try: + result = fn(*args, **kwargs) + future.set_result(result) + except Exception as e: + future.set_exception(e) + return future + + def shutdown(self, wait=True, *, cancel_futures=False) -> None: + pass + + +# ============================================================================== + + +@dataclass(eq=True, kw_only=True, slots=True) +class StoragePartition: + """ + Metadata determining the relative storage location for the simulation + outputs of a single-generation :py:class:`.EcoliSim`, inside a hive + partition or hierarchical store (see :ref:`parquet_emitter`). + """ + + experiment_id: str + variant: int + lineage_seed: int + generation: int = field(init=False) + agent_id: str + + def __post_init__(self) -> None: + assert isinstance(self.experiment_id, str) + assert isinstance(self.variant, int) + assert isinstance(self.lineage_seed, int) + assert isinstance(self.agent_id, str) + self.generation = len(self.agent_id) + assert self.generation > 0 + + @property + def parent(self) -> Self: + """ + Metadata of the mother cell in the same cell lineage. 
+ """ + return replace(self, agent_id=self.agent_id[:-1]) + + +# ============================================================================== + + +class BufferedEmitter(Emitter, ABC): + """ + An extension to the :py:class:`~vivarium.core.emitter.Emitter` interface + that buffers emitted simulation data before writing it to persistent + storage. In particular, this interface is used by + :py:meth:`.EcoliSim.update_experiment` and + :py:meth:`.EngineProcess.next_update`. + + .. warning:: + :py:meth:`~.finalize` must be explicitly called in a + ``try ... finally ...`` block around the call to + :py:meth:`vivarium.core.engine.Engine.update`, in order to ensure that + all buffered emits are written out when the simulation terminates for + any reason. + """ + + def __init__(self) -> None: + """ + .. warning:: + This method should be called **at the end** of a subclass + ``__init__()``. + """ + self.finalized: bool = False + """ + Flag set by :py:meth:`.finalize` after writing the last buffer. + """ + + @abstractmethod + def reset_emit_flags( + self, *, + engine: Engine, agent: HierarchyPath, emit_paths: tuple[HierarchyPath] + ) -> None: + """ + Reconfigure the simulation engine to avoid futile data marshalling, by + suppressing all default emissions and enabling only stores that were + explicitly requested by this emitter's configuration. + + Called by: :py:meth:`.EcoliSim.run` or + :py:meth:`.EngineProcess.create_emitter`. + """ + ... + + def extract_partition(self, metadata: dict[str, Any], /) -> StoragePartition: + """ + Define the current :py:class:`StoragePartition` from the simulation + metadata received via :py:meth:`!Engine._emit_configuration`. 
+ """ + return StoragePartition( + experiment_id=parse.quote_plus( + metadata.get("experiment_id", "default")), + variant=int(metadata.get("variant", 0)), + lineage_seed=int(metadata.get("lineage_seed", 0)), + agent_id=metadata.get("agent_id", "1")) + + def finalize(self, *, success: bool = False) -> None: + """ + Emit the partially filled buffer at the end of a single-generation + simulation. + + Args: + success: Indicates whether the simulation reached a + :py:exc:`.DivisionDetected` event. + """ + if self.finalized: + raise RuntimeError( + f"`{type(self).__name__}.finalize()` was already called.") + assert isinstance(success, bool) + self._finalize(success=success) + self.finalized = True + + @abstractmethod + def _finalize(self, *, success: bool) -> None: + """ + Called by: :py:meth:`.finalize`. + """ + ... + + def __del__(self) -> None: + """ + When a successfully initialised :py:class:`.BufferedEmitter` instance is + destroyed, check that its last batch has been flushed by the simulation + loop. 
+ """ + if not getattr(self, "finalized", True): + warn(f"\n `{type(self).__name__}.finalize()` was never called.") diff --git a/ecoli/library/parquet_emitter.py b/ecoli/library/parquet_emitter.py index 5e2a9251c..607ed3765 100644 --- a/ecoli/library/parquet_emitter.py +++ b/ecoli/library/parquet_emitter.py @@ -1,7 +1,9 @@ + import os import fnmatch from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, cast, Mapping, Optional +from dataclasses import asdict +from typing import Any, Callable, Mapping, Optional, cast, final from urllib import parse import duckdb @@ -12,7 +14,15 @@ from fsspec.core import filesystem, url_to_fs, OpenFile from fsspec.spec import AbstractFileSystem from tqdm import tqdm -from vivarium.core.emitter import Emitter + +from vivarium.core.types import HierarchyPath +from vivarium.core.engine import Engine + +from .emitter import BlockingExecutor, BufferedEmitter + + +# ============================================================================== + METADATA_PREFIX = "output_metadata__" """ @@ -62,6 +72,9 @@ """uint32 is 2x smaller than int64 for values between 0 - 4,294,967,295.""" +# ============================================================================== + + def json_to_parquet( emit_dict: dict[str, np.ndarray | list[pl.Series]], outfile: str, @@ -839,22 +852,11 @@ def pl_dtype_from_ndarray(arr: np.ndarray) -> pl.DataType: return pl_dtype -class BlockingExecutor: - def submit(self, fn: Callable, *args, **kwargs) -> Future: - """ - Run function in the current thread and return a Future that - is already done. - """ - future: Future = Future() - try: - result = fn(*args, **kwargs) - future.set_result(result) - except Exception as e: - future.set_exception(e) - return future +# ============================================================================== -class ParquetEmitter(Emitter): +@final +class ParquetEmitter(BufferedEmitter): """ Emit data to a Parquet dataset. 
Note that :py:meth:`~.finalize` must be explicitly called in a ``try...finally`` block around the call to @@ -907,22 +909,32 @@ def __init__(self, config: dict[str, Any]) -> None: # was successfully written to Parquet in order to avoid blocking self.last_batch_future: Future = Future() self.last_batch_future.set_result(None) - # Set either by EcoliSim or by EngineProcess if sim reaches division - self.success = False - - def finalize(self): - """Convert remaining batched emits to Parquet at sim shutdown - and mark sim as successful if ``success`` flag was set. In vEcoli, - this is done by :py:class:`~ecoli.experiments.ecoli_master_sim.EcoliSim` - upon reaching division. + super().__init__() + + def reset_emit_flags( + self, *, + engine: Engine, agent: HierarchyPath, emit_paths: tuple[HierarchyPath] + ) -> None: + """ + In this subclass, ``agent`` is ignored and ``emit_paths`` is interpreted + as a global path. + """ + assert engine.emitter is self + if emit_paths: + state = engine.state + state.set_emit_value(emit=False, path=tuple()) + state.set_emit_values(emit=True, paths=emit_paths) + + def _finalize(self, *, success: bool): + """ + Convert remaining batched emits to Parquet at sim shutdown and mark sim + as successful if ``success`` flag was set. 
""" # Wait for last batch to finish writing self.last_batch_future.result() # Flush any remaining buffered emits to Parquet outfile = os.path.join( - self.out_uri, - self.experiment_id, - "history", + self.out_uri, self.experiment_id, "history", self.partitioning_path, f"{self.num_emits}.pq", ) @@ -934,11 +946,9 @@ def finalize(self): self.buffered_emits, outfile, self.pl_types, self.filesystem ) # Hive-partitioned directory that only contains successful sims - if self.success: + if success: success_file = os.path.join( - self.out_uri, - self.experiment_id, - "success", + self.out_uri, self.experiment_id, "success", self.partitioning_path, "s.pq", ) @@ -986,21 +996,10 @@ def emit(self, data: dict[str, Any]): data = {**data["data"].pop("metadata", {}), **data["data"]} data["time"] = data.get("initial_global_time", 0.0) # Manually create filepaths with hive partitioning - agent_id = data.get("agent_id", "1") - quoted_experiment_id = parse.quote_plus( - data.get("experiment_id", "default") - ) - partitioning_keys = { - "experiment_id": quoted_experiment_id, - "variant": data.get("variant", 0), - "lineage_seed": data.get("lineage_seed", 0), - "generation": len(agent_id), - "agent_id": agent_id, - } - self.experiment_id = quoted_experiment_id - self.partitioning_path = os.path.join( - *(f"{k}={v}" for k, v in partitioning_keys.items()) - ) + partition = self.extract_partition(data) + self.partitioning_path = os.path.join(*( + f"{k}={v}" for (k, v) in asdict(partition).items())) + self.experiment_id = partition.experiment_id data = flatten_dict(data) config_emit: dict[str, Any] = {} config_schema: dict[str, pl.DataType] = {} @@ -1014,9 +1013,7 @@ def emit(self, data: dict[str, Any]): config_emit[k] = v config_schema[k] = v.dtype outfile = os.path.join( - self.out_uri, - self.experiment_id, - "configuration", + self.out_uri, self.experiment_id, "configuration", self.partitioning_path, "config.pq", ) @@ -1036,7 +1033,8 @@ def emit(self, data: dict[str, Any]): ) # Delete 
any sim output files in final filesystem history_outdir = os.path.join( - self.out_uri, self.experiment_id, "history", self.partitioning_path + self.out_uri, self.experiment_id, "history", + self.partitioning_path ) try: self.filesystem.delete(history_outdir, recursive=True) @@ -1124,9 +1122,7 @@ def emit(self, data: dict[str, Any]): # If last batch of emits failed, exception should be raised here self.last_batch_future.result() outfile = os.path.join( - self.out_uri, - self.experiment_id, - "history", + self.out_uri, self.experiment_id, "history", self.partitioning_path, f"{self.num_emits}.pq", ) diff --git a/ecoli/library/test_parquet_emitter.py b/ecoli/library/test_parquet_emitter.py index 526b8d10d..833d7ffaa 100644 --- a/ecoli/library/test_parquet_emitter.py +++ b/ecoli/library/test_parquet_emitter.py @@ -1,7 +1,5 @@ import os import re -import tempfile -import shutil import duckdb import numpy as np import polars as pl @@ -254,7 +252,7 @@ def test_union_pl_dtypes(self): pl.UInt32, ) == pl.List(pl.List(pl.List(pl.UInt32))) - def test_quote_columns(self): + def test_quote_columns(self, tmp_path): """Test quote_columns handles special characters correctly.""" # Test single string with special characters assert quote_columns("simple") == '"simple"' @@ -291,116 +289,114 @@ def test_quote_columns(self): assert quote_columns([]) == [] # Test that quoted columns actually work in DuckDB queries with weird column names - with tempfile.TemporaryDirectory() as tmp_path: - test_file = os.path.join(tmp_path, "weird_cols.parquet") - # Create test data with columns containing special characters - test_data = pl.DataFrame( - { - "simple": [1, 2, 3], - "with spaces": [4, 5, 6], - "with-hyphens": [7, 8, 9], - "with[brackets]": [10, 11, 12], - "with/slashes": [13, 14, 15], - 'has"quote': [16, 17, 18], - "dot.name": [19, 20, 21], - "colon:name": [22, 23, 24], - } - ) - test_data.write_parquet(test_file, statistics=False) - - conn = create_duckdb_conn() - - # Test selecting 
individual columns with special characters - for col in test_data.columns: - quoted_col = quote_columns(col) - result = conn.sql(f"SELECT {quoted_col} FROM '{test_file}'").pl() - assert result.shape == (3, 1) - assert result.columns[0] == col - expected_values = test_data[col].to_list() - assert result[col].to_list() == expected_values - - # Test selecting multiple columns at once - weird_cols = ["with spaces", "with-hyphens", "with[brackets]", 'has"quote'] - quoted_cols = ", ".join(quote_columns(weird_cols)) - result = conn.sql(f"SELECT {quoted_cols} FROM '{test_file}'").pl() - assert result.shape == (3, 4) - for col in weird_cols: - assert col in result.columns - assert result[col].to_list() == test_data[col].to_list() - - # Test that using WHERE clause works with quoted columns - quoted_space_col = quote_columns("with spaces") - result = conn.sql( - f"SELECT * FROM '{test_file}' WHERE {quoted_space_col} > 4" - ).pl() - assert result.shape == (2, 8) - assert result["with spaces"].to_list() == [5, 6] - - # Test aggregation with quoted columns - quoted_bracket_col = quote_columns("with[brackets]") - result = conn.sql( - f"SELECT AVG({quoted_bracket_col}) as avg_val FROM '{test_file}'" - ).pl() - assert result["avg_val"][0] == 11.0 - - # Test ORDER BY with quoted columns - quoted_slash_col = quote_columns("with/slashes") - result = conn.sql( - f"SELECT {quoted_slash_col} FROM '{test_file}' ORDER BY {quoted_slash_col} DESC" - ).pl() - assert result["with/slashes"].to_list() == [15, 14, 13] - - def test_list_columns(self): + test_file = os.path.join(tmp_path, "weird_cols.parquet") + # Create test data with columns containing special characters + test_data = pl.DataFrame( + { + "simple": [1, 2, 3], + "with spaces": [4, 5, 6], + "with-hyphens": [7, 8, 9], + "with[brackets]": [10, 11, 12], + "with/slashes": [13, 14, 15], + 'has"quote': [16, 17, 18], + "dot.name": [19, 20, 21], + "colon:name": [22, 23, 24], + } + ) + test_data.write_parquet(test_file, statistics=False) + 
+ conn = create_duckdb_conn() + + # Test selecting individual columns with special characters + for col in test_data.columns: + quoted_col = quote_columns(col) + result = conn.sql(f"SELECT {quoted_col} FROM '{test_file}'").pl() + assert result.shape == (3, 1) + assert result.columns[0] == col + expected_values = test_data[col].to_list() + assert result[col].to_list() == expected_values + + # Test selecting multiple columns at once + weird_cols = ["with spaces", "with-hyphens", "with[brackets]", 'has"quote'] + quoted_cols = ", ".join(quote_columns(weird_cols)) + result = conn.sql(f"SELECT {quoted_cols} FROM '{test_file}'").pl() + assert result.shape == (3, 4) + for col in weird_cols: + assert col in result.columns + assert result[col].to_list() == test_data[col].to_list() + + # Test that using WHERE clause works with quoted columns + quoted_space_col = quote_columns("with spaces") + result = conn.sql( + f"SELECT * FROM '{test_file}' WHERE {quoted_space_col} > 4" + ).pl() + assert result.shape == (2, 8) + assert result["with spaces"].to_list() == [5, 6] + + # Test aggregation with quoted columns + quoted_bracket_col = quote_columns("with[brackets]") + result = conn.sql( + f"SELECT AVG({quoted_bracket_col}) as avg_val FROM '{test_file}'" + ).pl() + assert result["avg_val"][0] == 11.0 + + # Test ORDER BY with quoted columns + quoted_slash_col = quote_columns("with/slashes") + result = conn.sql( + f"SELECT {quoted_slash_col} FROM '{test_file}' ORDER BY {quoted_slash_col} DESC" + ).pl() + assert result["with/slashes"].to_list() == [15, 14, 13] + + def test_list_columns(self, tmp_path): """Test list_columns retrieves column names correctly.""" - with tempfile.TemporaryDirectory() as tmp_path: - # Create test Parquet file with known columns - test_file = os.path.join(tmp_path, "test.parquet") - test_data = pl.DataFrame( - { - "col_a": [1, 2, 3], - "col_b": [4.0, 5.0, 6.0], - "listeners__mass__cell_mass": [7.0, 8.0, 9.0], - "listeners__mass__dry_mass": [10.0, 11.0, 12.0], - 
"listeners__growth__instantaneous_growth_rate": [0.1, 0.2, 0.3], - "bulk": [[1, 2], [3, 4], [5, 6]], - } - ) - test_data.write_parquet(test_file, statistics=False) + # Create test Parquet file with known columns + test_file = os.path.join(tmp_path, "test.parquet") + test_data = pl.DataFrame( + { + "col_a": [1, 2, 3], + "col_b": [4.0, 5.0, 6.0], + "listeners__mass__cell_mass": [7.0, 8.0, 9.0], + "listeners__mass__dry_mass": [10.0, 11.0, 12.0], + "listeners__growth__instantaneous_growth_rate": [0.1, 0.2, 0.3], + "bulk": [[1, 2], [3, 4], [5, 6]], + } + ) + test_data.write_parquet(test_file, statistics=False) - conn = create_duckdb_conn() - subquery = f"SELECT * FROM '{test_file}'" + conn = create_duckdb_conn() + subquery = f"SELECT * FROM '{test_file}'" - # Test getting all columns - all_cols = list_columns(conn, subquery) - assert len(all_cols) == 6 - assert "col_a" in all_cols - assert "col_b" in all_cols - assert "listeners__mass__cell_mass" in all_cols + # Test getting all columns + all_cols = list_columns(conn, subquery) + assert len(all_cols) == 6 + assert "col_a" in all_cols + assert "col_b" in all_cols + assert "listeners__mass__cell_mass" in all_cols - # Test pattern matching with glob patterns - listener_cols = list_columns(conn, subquery, "listeners__*") - assert len(listener_cols) == 3 - assert all(col.startswith("listeners__") for col in listener_cols) + # Test pattern matching with glob patterns + listener_cols = list_columns(conn, subquery, "listeners__*") + assert len(listener_cols) == 3 + assert all(col.startswith("listeners__") for col in listener_cols) - # Test pattern matching for specific listener - mass_cols = list_columns(conn, subquery, "listeners__mass__*") - assert len(mass_cols) == 2 - assert "listeners__mass__cell_mass" in mass_cols - assert "listeners__mass__dry_mass" in mass_cols + # Test pattern matching for specific listener + mass_cols = list_columns(conn, subquery, "listeners__mass__*") + assert len(mass_cols) == 2 + assert 
"listeners__mass__cell_mass" in mass_cols + assert "listeners__mass__dry_mass" in mass_cols - # Test pattern that matches nothing - no_match = list_columns(conn, subquery, "nonexistent__*") - assert len(no_match) == 0 + # Test pattern that matches nothing + no_match = list_columns(conn, subquery, "nonexistent__*") + assert len(no_match) == 0 - # Test pattern with single character wildcard - col_pattern = list_columns(conn, subquery, "col_?") - assert len(col_pattern) == 2 - assert "col_a" in col_pattern - assert "col_b" in col_pattern + # Test pattern with single character wildcard + col_pattern = list_columns(conn, subquery, "col_?") + assert len(col_pattern) == 2 + assert "col_a" in col_pattern + assert "col_b" in col_pattern - # Test exact match pattern - exact = list_columns(conn, subquery, "bulk") - assert exact == ["bulk"] + # Test exact match pattern + exact = list_columns(conn, subquery, "bulk") + assert exact == ["bulk"] def compare_nested(a: list, b: list) -> bool: @@ -420,21 +416,16 @@ def compare_nested(a: list, b: list) -> bool: class TestParquetEmitter: - @pytest.fixture - def temp_dir(self): - """Create a temporary directory for testing.""" - tmp = tempfile.mkdtemp() - yield tmp - shutil.rmtree(tmp) - def test_initialization(self, temp_dir): + def test_initialization(self, tmp_path): """Test ParquetEmitter initialization with different configs.""" # Test with out_dir - emitter = ParquetEmitter({"out_dir": temp_dir}) + emitter = ParquetEmitter({"out_dir": tmp_path}) emitter.experiment_id = "test_exp" emitter.partitioning_path = "path/to/output" - assert emitter.out_uri == os.path.abspath(temp_dir) + assert emitter.out_uri == os.path.abspath(tmp_path) assert emitter.batch_size == 400 + emitter.finalized = True # Test with out_uri and custom batch size emitter = ParquetEmitter({"out_uri": "gs://bucket/path", "batch_size": 100}) @@ -442,10 +433,11 @@ def test_initialization(self, temp_dir): emitter.partitioning_path = "path/to/output" assert 
emitter.out_uri == "gs://bucket/path" assert emitter.batch_size == 100 + emitter.finalized = True - def test_emit_configuration(self, temp_dir): + def test_emit_configuration(self, tmp_path): """Test emitting configuration data.""" - emitter = ParquetEmitter({"out_dir": temp_dir}) + emitter = ParquetEmitter({"out_dir": tmp_path}) # Setup ThreadPoolExecutor mock future = Future() @@ -466,6 +458,7 @@ def test_emit_configuration(self, temp_dir): } emitter.emit(config_data) + emitter.finalized = True # Verify partitioning path assert emitter.experiment_id == "test_exp" @@ -477,9 +470,9 @@ def test_emit_configuration(self, temp_dir): args, _ = emitter.executor.submit.call_args assert args[0] == json_to_parquet - def test_emit_simulation_data(self, temp_dir): + def test_emit_simulation_data(self, tmp_path): """Test emitting simulation data with various types.""" - emitter = ParquetEmitter({"out_dir": temp_dir, "batch_size": 2}) + emitter = ParquetEmitter({"out_dir": tmp_path, "batch_size": 2}) # Configuration emit to initialize variables config_data = { @@ -542,6 +535,7 @@ def test_emit_simulation_data(self, temp_dir): emitter.emit(sim_data1) assert emitter.num_emits == 2 emitter.last_batch_future.result() + emitter.finalized = True # Check output t = pl.read_parquet( @@ -561,9 +555,9 @@ def test_emit_simulation_data(self, temp_dir): assert all(t["nested__value"] == [100] * 2) assert emitter.buffered_emits == {} - def test_variable_length_arrays(self, temp_dir): + def test_variable_length_arrays(self, tmp_path): """Test handling arrays with changing dimensions.""" - emitter = ParquetEmitter({"out_dir": temp_dir, "batch_size": 3}) + emitter = ParquetEmitter({"out_dir": tmp_path, "batch_size": 3}) # Configuration emit to initialize variables config_data = { "table": "configuration", @@ -629,6 +623,7 @@ def test_variable_length_arrays(self, temp_dir): # Write to Parquet and check output emitter.emit(sim_data2) emitter.last_batch_future.result() + emitter.finalized = True t 
= pl.read_parquet( os.path.join( @@ -650,9 +645,9 @@ def test_variable_length_arrays(self, temp_dir): [[1], [1, 2], [1, 2, 3]], ] - def test_extreme_data_types(self, temp_dir): + def test_extreme_data_types(self, tmp_path): """Test with extreme data types and edge cases.""" - emitter = ParquetEmitter({"out_dir": temp_dir, "batch_size": 2}) + emitter = ParquetEmitter({"out_dir": tmp_path, "batch_size": 2}) # Create test data with extreme values and special cases sim_data = { "table": "configuration", @@ -790,6 +785,7 @@ def test_extreme_data_types(self, temp_dir): emitter.emit(sim_data_2) emitter.last_batch_future.result() assert emitter.buffered_emits == {} + emitter.finalized = True out_path = os.path.join( emitter.out_uri, @@ -852,9 +848,9 @@ def test_extreme_data_types(self, temp_dir): f"Mismatch in field {key}" ) - def test_finalize(self, temp_dir): + def test_finalize(self, tmp_path): """Test finalize method that handles remaining data.""" - emitter = ParquetEmitter({"out_dir": temp_dir}) + emitter = ParquetEmitter({"out_dir": tmp_path}) emitter.experiment_id = "test_exp" emitter.partitioning_path = "path/to/output" @@ -886,8 +882,8 @@ def test_finalize(self, temp_dir): assert args[0]["field2"][0] == 20.5 # Test success flag - emitter.success = True - emitter.finalize() + emitter.finalized = False + emitter.finalize(success=True) assert os.path.exists( os.path.join( emitter.out_uri, @@ -898,8 +894,8 @@ def test_finalize(self, temp_dir): ) ) - def test_multiple_agents(self, temp_dir): - emitter = ParquetEmitter({"out_dir": temp_dir}) + def test_multiple_agents(self, tmp_path): + emitter = ParquetEmitter({"out_dir": tmp_path}) emitter.experiment_id = "test_exp" emitter.partitioning_path = "path/to/output" @@ -916,11 +912,12 @@ def test_multiple_agents(self, temp_dir): emitter.emit(sim_data) assert emitter.num_emits == 0 assert emitter.buffered_emits == {} + emitter.finalized = True - def test_batch_processing(self, temp_dir): + def test_batch_processing(self, 
tmp_path): """Test multiple emits and batch processing.""" # Small batch size for testing - emitter = ParquetEmitter({"out_dir": temp_dir, "batch_size": 3}) + emitter = ParquetEmitter({"out_dir": tmp_path, "batch_size": 3}) # Configuration emit to initialize variables config_data = { @@ -947,6 +944,7 @@ def test_batch_processing(self, temp_dir): sim_data["data"]["agents"]["agent1"]["value"] = i * 10 emitter.emit(sim_data) emitter.last_batch_future.result() + emitter.finalized = True # Verify batch was processed assert emitter.num_emits == 4 @@ -957,15 +955,9 @@ def test_batch_processing(self, temp_dir): class TestParquetEmitterEdgeCases: - @pytest.fixture - def temp_dir(self): - """Create a temporary directory for testing.""" - tmp = tempfile.mkdtemp() - yield tmp - shutil.rmtree(tmp) @patch("ecoli.library.parquet_emitter.ThreadPoolExecutor") - def test_multithreaded_buffer_clearing(self, mock_executor_class, temp_dir): + def test_multithreaded_buffer_clearing(self, mock_executor_class, tmp_path): """ Test to verify that clearing buffers after submitting to ThreadPoolExecutor doesn't cause race conditions with the worker thread. 
@@ -1006,7 +998,7 @@ def delayed_execution(): mock_executor_class.return_value = mock_executor # Initialize the emitter with a small batch size - emitter = ParquetEmitter({"out_dir": temp_dir, "batch_size": 2}) + emitter = ParquetEmitter({"out_dir": tmp_path, "batch_size": 2}) # Configuration emit to initialize variables config_data = { "table": "configuration", @@ -1086,16 +1078,17 @@ def delayed_execution(): # Changed type for field2 to list so should fail with pytest.raises(pl.exceptions.InvalidOperationError): emitter.finalize() + emitter.finalized = True # Cleanup the real executor real_executor.shutdown() - def test_variable_shape_detection_at_boundaries(self, temp_dir): + def test_variable_shape_detection_at_boundaries(self, tmp_path): """ Test the fixed vs variable shape field detection logic specifically at the boundary points (start of sim, after disk write). """ # Use a small batch size to quickly hit the boundary - emitter = ParquetEmitter({"out_dir": temp_dir, "batch_size": 3}) + emitter = ParquetEmitter({"out_dir": tmp_path, "batch_size": 3}) # Setup: Emit configuration data to intitialize variables config_data = { @@ -1198,6 +1191,7 @@ def test_variable_shape_detection_at_boundaries(self, temp_dir): emitter.emit(sim_data4) emitter.last_batch_future.result() + emitter.finalized = True t = pl.read_parquet( os.path.join( emitter.out_uri, @@ -1224,12 +1218,12 @@ def test_variable_shape_detection_at_boundaries(self, temp_dir): [[1], [2], [3], [4], [5]], ] - def test_expected_failures(self, temp_dir): + def test_expected_failures(self, tmp_path): """ Test a few cases that are expected to fail. 
""" # Use a small batch size to quickly hit the boundary - emitter = ParquetEmitter({"out_dir": temp_dir, "batch_size": 3}) + emitter = ParquetEmitter({"out_dir": tmp_path, "batch_size": 3}) # Setup: Emit configuration data to intitialize variables config_data = { @@ -1394,10 +1388,11 @@ def test_expected_failures(self, temp_dir): match=re.escape("cannot parse numpy data type dtype('O')"), ): emitter.emit(sim_data7) + emitter.finalized = True - def test_nested_nullable(self, temp_dir): + def test_nested_nullable(self, tmp_path): """Test handling nullable nested types that increase in depth.""" - emitter = ParquetEmitter({"out_dir": temp_dir, "batch_size": 4}) + emitter = ParquetEmitter({"out_dir": tmp_path, "batch_size": 4}) # Configuration emit to initialize variables config_data = { "table": "configuration", @@ -1518,6 +1513,7 @@ def test_nested_nullable(self, temp_dir): for _ in range(3): emitter.emit(sim_data1) emitter.last_batch_future.result() + emitter.finalized = True # Check output t = pl.read_parquet( diff --git a/ecoli/library/test_utils.py b/ecoli/library/test_utils.py new file mode 100644 index 000000000..1fbdd390e --- /dev/null +++ b/ecoli/library/test_utils.py @@ -0,0 +1,99 @@ + +""" +Utilities for patching execution environments, configurations and functions. 
+""" + + +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import reduce +from inspect import ismethod +from unittest.mock import Mock, DEFAULT, _patch, patch +from typing import Any + +import pytest + +from ecoli.library.xarray_emitter.utils import WarningFilter + + +# ============================================================================== +# warnings +# ============================================================================== + + +def filter_warnings(filters: list[WarningFilter]) -> Callable[[Callable], Callable]: + """ + Analogue of :py:func:`ecoli.library.xarray_emitter.utils.filter_warnings`, + but with the effect of applying :py:func:`pytest.mark.filterwarnings` + decorators, instead of :py:func:`warnings.filterwarnings` context modifiers. + """ + return (lambda func: reduce( + lambda fun, wf: pytest.mark.filterwarnings(str(wf))(fun), + filters, func)) + + + +# ============================================================================== +# config patching +# ============================================================================== + + +class PatchConfig(ABC): + """ + Test parameter for modifying an already loaded baseline JSON configuration. + """ + + @abstractmethod + def to_dict(self) -> dict[str, Any]: + """ + Materialise changes to the JSON configuration. + """ + ... + + +# ============================================================================== +# code patching +# ============================================================================== + + +def patch_func(func: str, *, cb: Callable | None = None) -> _patch: + """ + Create a context manager which patches a module-level function, in order to + trace its calls, and to optionally pre-apply a callback. + + .. note:: + ``func`` is passed as the argument ``target`` to + :py:func:`unittest.mock.patch`. 
+ """ + mocked = None + def side_effect(*args, **kwargs) -> Any: + nonlocal cb, mocked + if cb is not None: + cb(*args, **kwargs) + return mocked.temp_original(*args, **kwargs) # type: ignore[attr-defined] + mocked = patch(func, side_effect=side_effect) + return mocked + + +# ------------------------------------------------------------------------------ + + +def patch_meth( + obj: object, meth: str, *, + cb: Callable | None = None, modargs: Callable | None = None +) -> None: + """ + Patch an object instance method, in order to trace its calls, and to + optionally pre-apply a callback or argument modification. + """ + assert ismethod(getattr(obj, meth)) + assert cb is None or modargs is None + def side_effect(*args, **kwargs): + nonlocal obj, meth, cb, modargs + if modargs is not None: + _args, _kwargs = modargs(obj, *args, **kwargs) + return getattr(obj, meth)._mock_wraps(*_args, **_kwargs) + elif cb is not None: + cb(obj, *args, **kwargs) + return DEFAULT + setattr(obj, meth, Mock(wraps=getattr(obj, meth), side_effect=side_effect)) diff --git a/ecoli/library/xarray_emitter/__init__.py b/ecoli/library/xarray_emitter/__init__.py new file mode 100644 index 000000000..6ccb43e76 --- /dev/null +++ b/ecoli/library/xarray_emitter/__init__.py @@ -0,0 +1,288 @@ + +r""" + +Introduction +============ + +:py:class:`.XarrayEmitter` is an :py:class:`~vivarium.core.emitter.Emitter` +similar to :py:class:`~.ParquetEmitter`, but with a design optimized towards a +different flavour of downstream applications: + +- :py:class:`~.ParquetEmitter` is geared towards emitting a significant fraction + of the simulator state, in a format that supports flexible sparse selections, + `data reductions`_ and time series visualizations, as used in :ref:`analysis + scripts `. +- :py:class:`.XarrayEmitter` is intended for emitting only a pre-selected subset + of statically shaped tensor variables, in a format that supports numerical + algorithms in the high-dimensional and large-sample regime. 
+ +The former type of computations is naturally expressed using `relational query`_ +engines (e.g., :ref:`DuckDB `), whereas the latter type is +naturally expressed using `array programming`_ libraries (e.g., `Cubed`_). Due +to the sheer size of the simulator state, both types may in general require +`out-of-core processing`_ algorithms. + +.. _data reductions: https://en.wikipedia.org/wiki/Data_reduction +.. _relational query: https://en.wikipedia.org/wiki/Relational_database +.. _array programming: https://en.wikipedia.org/wiki/Array_programming +.. _Cubed: https://cubed-dev.github.io/cubed/why-cubed.html +.. _out-of-core processing: https://en.wikipedia.org/wiki/External_memory_algorithm + +In order to facilitate downstream applications based on chunked array +processing, :py:class:`.XarrayEmitter` writes out to any persistent storage +supporting the `Zarr`_ specification, using an in-memory buffer comprised of +`Xarray`_ objects. For optimized throughput, the buffer implements temporal +subsampling, numerical type casting and compression codecs at emission time. +Furthermore, in order to simplify the export of simulation data into external +libraries, the hierarchy of the output `DataTree`_ is decoupled from the +hierarchy of simulation :py:class:`~vivarium.core.store.Store`\ s, using an +output :ref:`variable layout ` specified in the :ref:`simulator +configuration `. + +.. _Xarray: https://xarray.dev/ +.. _Zarr: https://zarr.dev/ +.. _DataTree: https://docs.xarray.dev/en/stable/user-guide/data-structures.html#datatree + + +Comparison with :py:class:`~.ParquetEmitter` +============================================ + +Similarities +------------ + +- Currently only supports simulations of a *single-cell lineage* per + :py:class:`.BufferedEmitter` instance. +- Executes at every time step. +- Buffers emissions into time chunks. +- Uses concurrent threads for writing buffers to persistent storage. 
+- Produces a hierarchically structured storage layout that supports selective + reading in downstream applications. + +Differences in usage +-------------------- + +- Supports the configuration of *emission predicates*. +- Currently only supports emitting a *static collection* of *statically shaped + tensor variables*. +- Supports *renaming and rearranging* of output variables. +- Requires the configuration of *output data types*. +- Supports the configuration of backend-specific *compression codecs*. +- Supports :ref:`log_updates`, i.e., the emission of individual + :py:class:`~vivarium.core.process.Process` update requests, before they are + aggregated and reallocated by + :py:func:`~ecoli.processes.allocator.calculatePartition` and then applied to + the global cell state by + :py:meth:`~ecoli.processes.partition.PartitionedProcess.evolve_state`. + +.. note:: + See :py:class:`.XarrayEmitter` for an explanation of the JSON configuration + syntax, and ``configs/test_configs/test_xarray_emitter.json`` for a complete + example. + +.. hint:: + As data structures, `DataTree`_\ s could support changes of variable names and + dimensions across time steps. The constraints currently imposed by + :py:class:`.XarrayEmitter` rather serve to enable I/O optimizations for the + intended use cases. When access to variably sized simulation variables is + desired, users have the choice either of implementing custom :ref:`listeners + ` with static output coordinates, or otherwise of defaulting to the + :py:class:`~.ParquetEmitter`. + +Differences in implementation +----------------------------- + +- Uses the `Xarray`_ API for serialization, buffering, and `metadata + organization`_, including unit annotations (see :py:class:`.VariableSpec` and + :py:class:`.XarrayTransducer` for details). 
+- Applies a "*process*-major" rather than a "*generation*-major" output layout, + reflecting array variables directly in the output directory tree; this + produces one file per *variable* time chunk, rather than one file per + *simulation* time chunk (compare the :py:class:`.XarrayEmitter` :ref:`storage + layout <storage_layout>` with :py:meth:`.ParquetEmitter.emit`; see + :py:class:`.XarrayStoragePartition` for details). +- Defines the abstract interface :py:class:`.AsyncBufferWriter` for storage + backends with *asynchronous* APIs (currently supported: `Zarr`_), realizing + the opportunity for :ref:`concurrency ` among multiple + `DataArray`_\ s within an output buffer. +- Decouples the in-memory buffer size from the persistent chunk size, in order + to simplify performance tuning of large-scale simulations (see + :py:class:`.XarrayTransducer` and :py:class:`.AsyncBufferWriter` for details). +- Maintains `consolidated metadata`_ and updates it at the end of each simulated + cell generation, in order to reduce the metadata loading latency for + subsequent storage reads (see :py:class:`.AsyncZarrBufferWriter` for details). + +.. _metadata organization: https://docs.xarray.dev/en/stable/get-help/faq.html#approach-to-metadata +.. _DataArray: https://docs.xarray.dev/en/stable/user-guide/data-structures.html#dataarray +.. _consolidated metadata: https://docs.xarray.dev/en/stable/user-guide/io.html#io-zarr-consolidated-metadata + + +.. _storage_layout: + +Storage layout +============== + +The workflow storage layout, which comprises many individual simulations, is +currently organized as follows --- where file paths in this example are specific +to the Zarr v3 storage backend:: + + {store} ; + ├─ zarr.json ; metadata + └─ experiment_id={}/variant={}/lineage_seed={} ; + ├─ zarr.json ; consolidated metadata + ├─ emitstep_gen={} ;