diff --git a/doc/changes/DM-54879.feature.md b/doc/changes/DM-54879.feature.md new file mode 100644 index 000000000..0a50adcb0 --- /dev/null +++ b/doc/changes/DM-54879.feature.md @@ -0,0 +1 @@ +Added `retry_missing_outputs_for` parameter to `QuantumGraphBuilder` and `SeparablePipelineExecutor`. diff --git a/python/lsst/pipe/base/quantum_graph_builder.py b/python/lsst/pipe/base/quantum_graph_builder.py index f230134f4..e7b289541 100644 --- a/python/lsst/pipe/base/quantum_graph_builder.py +++ b/python/lsst/pipe/base/quantum_graph_builder.py @@ -128,6 +128,14 @@ class QuantumGraphBuilder(ABC): skip_existing_in : `~collections.abc.Sequence` [ `str` ], optional Collections to search for outputs that already exist for the purpose of skipping quanta that have already been run. + retry_missing_outputs_for : `~collections.abc.Iterable` [ `str` ], optional + Task labels for which the task metadata dataset is not used as the + completion signal when ``skip_existing_in`` is provided. For these + tasks a quantum is skipped only when all of its regular outputs + are present in ``skip_existing_in``. This is useful for pipelines + where upstream processing does not retain all task outputs, so that + those tasks can be re-run to regenerate the missing intermediate + datasets. clobber : `bool`, optional Whether to raise if predicted outputs already exist in ``output_run`` (not including those quanta that would be skipped because they've @@ -171,6 +179,7 @@ def __init__( input_collections: Sequence[str] | None = None, output_run: str | None = None, skip_existing_in: Sequence[str] = (), + retry_missing_outputs_for: Iterable[str] = (), clobber: bool = False, ): self.log = getLogger(__name__) @@ -188,6 +197,7 @@ def __init__( self.butler = butler.clone(collections=input_collections) self.output_run = output_run self.skip_existing_in = skip_existing_in + self.retry_missing_outputs_for: frozenset[str] = frozenset(retry_missing_outputs_for) self.empty_data_id = DataCoordinate.make_empty(butler.dimensions) self.clobber = clobber # See whether the output run already exists. @@ -531,7 +541,7 @@ def _resolve_task_quanta(self, task_node: TaskNode, skeleton: QuantumGraphSkelet # gotten rid of. skipped_quanta = [] for quantum_key in skeleton.get_quanta(task_node.label): - if self._skip_quantum_if_metadata_exists(task_node, quantum_key, skeleton): + if self._skip_quantum_if_done(task_node, quantum_key, skeleton): skipped_quanta.append(quantum_key) continue quantum_data_id = skeleton[quantum_key]["data_id"] @@ -699,11 +709,11 @@ def _get_task_inputs_if_overall_only(self, task_node: TaskNode) -> list[str] | N return None return result - def _skip_quantum_if_metadata_exists( + def _skip_quantum_if_done( self, task_node: TaskNode, quantum_key: QuantumKey, skeleton: QuantumGraphSkeleton ) -> bool: """Identify and drop quanta that should be skipped because their - metadata datasets already exist. + metadata or output datasets already exist in ``skip_existing_in``. Parameters ---------- @@ -722,41 +732,67 @@ def _skip_quantum_if_metadata_exists( Notes ----- - If the metadata dataset for this quantum exists in the - `skip_existing_in` collections, the quantum will be skipped. This - causes the quantum node to be removed from the graph. Dataset nodes + For tasks not listed in `retry_missing_outputs_for`, a quantum is + skipped when its metadata dataset exists in ``skip_existing_in``. + + For tasks listed in `retry_missing_outputs_for`, the metadata dataset + is not used as the completion signal. Instead, a quantum is skipped + only when all of its regular outputs are present in + ``skip_existing_in``. If any such output is absent the quantum is + not skipped, so the task can regenerate the missing outputs. This + supports pipelines where upstream processing does not retain all task + outputs. + + The skipped quantum node is to be removed from the graph. Dataset nodes that were previously the outputs of this quantum will be associated with `lsst.daf.butler.DatasetRef` objects that were found in ``skip_existing_in``, or will be removed if there is no such dataset there. Any output dataset in `output_run` will be removed from the "output in the way" category. """ - metadata_dataset_key = DatasetKey( - task_node.metadata_output.parent_dataset_type_name, quantum_key.data_id_values - ) - if skeleton.get_output_for_skip(metadata_dataset_key): - # This quantum's metadata is already present in the the - # skip_existing_in collections; we'll skip it. But the presence of - # the metadata dataset doesn't guarantee that all of the other - # outputs we predicted are present; we have to check. - for output_dataset_key in list(skeleton.iter_outputs_of(quantum_key)): - # If this dataset was "in the way" (i.e. already in the - # output run), it isn't anymore. - skeleton.discard_output_in_the_way(output_dataset_key) - if (output_ref := skeleton.get_output_for_skip(output_dataset_key)) is not None: - # Populate the skeleton graph's node attributes - # with the existing DatasetRef, just like a - # predicted output of a non-skipped quantum. - skeleton.set_dataset_ref(output_ref, output_dataset_key) - else: - # Remove this dataset from the skeleton graph, - # because the quantum that would have produced it - # is being skipped and it doesn't already exist. - skeleton.remove_dataset_nodes([output_dataset_key]) - # Removing the quantum node from the graph will happen outside this - # function. - return True - return False + metadata_name = task_node.metadata_output.parent_dataset_type_name + metadata_dataset_key = DatasetKey(metadata_name, quantum_key.data_id_values) + log_name = task_node.log_output.parent_dataset_type_name if task_node.log_output is not None else None + + if task_node.label in self.retry_missing_outputs_for: + # For this task, use actual output datasets as the completion + # signal rather than metadata. Skip only if all regular + # outputs are present in skip_existing_in; do not skip if + # any are absent so they can be regenerated. + regular_output_keys = [ + k + for k in skeleton.iter_outputs_of(quantum_key) + if k.parent_dataset_type_name != metadata_name and k.parent_dataset_type_name != log_name + ] + # No regular outputs is treated as not done. + if not regular_output_keys or any( + skeleton.get_output_for_skip(k) is None for k in regular_output_keys + ): + return False + # All regular outputs are present; fall through to skip. + elif not skeleton.get_output_for_skip(metadata_dataset_key): + # metadata absent: do not skip. + return False + + # We will skip the quantum. But it doesn't guarantee that all of the + # other outputs we predicted are present; we have to check. + for output_dataset_key in list(skeleton.iter_outputs_of(quantum_key)): + # If this dataset was "in the way" (i.e. already in the + # output run), it isn't anymore. + skeleton.discard_output_in_the_way(output_dataset_key) + if (output_ref := skeleton.get_output_for_skip(output_dataset_key)) is not None: + # Populate the skeleton graph's node attributes + # with the existing DatasetRef, just like a + # predicted output of a non-skipped quantum. + skeleton.set_dataset_ref(output_ref, output_dataset_key) + else: + # Remove this dataset from the skeleton graph, + # because the quantum that would have produced it + # is being skipped and it doesn't already exist. + skeleton.remove_dataset_nodes([output_dataset_key]) + # Removing the quantum node from the graph will happen outside this + # function. + return True @final def _update_quantum_for_adjust( diff --git a/python/lsst/pipe/base/separable_pipeline_executor.py b/python/lsst/pipe/base/separable_pipeline_executor.py index 52d694ac9..c4deba983 100644 --- a/python/lsst/pipe/base/separable_pipeline_executor.py +++ b/python/lsst/pipe/base/separable_pipeline_executor.py @@ -84,6 +84,11 @@ class SeparablePipelineExecutor: for existing outputs, and skips any quanta that have run to completion (or have no work to do). Otherwise, all tasks are attempted (subject to ``clobber_output``). + retry_missing_outputs_for : `~collections.abc.Iterable` [`str`], optional + Task labels for which the completion signal used by + ``skip_existing_in`` is changed from task metadata to all regular + outputs existing. Those tasks are re-run if any regular outputs are + absent. Has no effect without ``skip_existing_in``. task_factory : `.TaskFactory`, optional A custom task factory for use in pre-execution and execution. By default, a new instance of `.TaskFactory` is used. @@ -101,6 +106,7 @@ def __init__( butler: Butler, clobber_output: bool = False, skip_existing_in: Iterable[str] | None = None, + retry_missing_outputs_for: Iterable[str] | None = None, task_factory: TaskFactory | None = None, resources: ExecutionResources | None = None, raise_on_partial_outputs: bool = True, @@ -115,6 +121,7 @@ def __init__( self._clobber_output = clobber_output self._skip_existing_in = list(skip_existing_in) if skip_existing_in else [] + self._retry_missing_outputs_for = list(retry_missing_outputs_for) if retry_missing_outputs_for else [] self._task_factory = task_factory if task_factory else TaskFactory() self.resources = resources @@ -216,6 +223,7 @@ class are provided automatically (from explicit arguments to this pipeline.to_graph(), self._butler, skip_existing_in=self._skip_existing_in, + retry_missing_outputs_for=self._retry_missing_outputs_for, clobber=self._clobber_output, **kwargs, ) @@ -276,6 +284,7 @@ class are provided automatically (from explicit arguments to this "output_run": self._butler.run, "skip_existing_in": self._skip_existing_in, "skip_existing": bool(self._skip_existing_in), + "retry_missing_outputs_for": self._retry_missing_outputs_for, "data_query": where, "user": getpass.getuser(), "time": str(datetime.datetime.now()), @@ -344,6 +353,7 @@ class are provided automatically (from explicit arguments to this metadata = { "skip_existing_in": self._skip_existing_in, "skip_existing": bool(self._skip_existing_in), + "retry_missing_outputs_for": self._retry_missing_outputs_for, "data_query": where, } qg_builder = self.make_quantum_graph_builder(pipeline, where, builder_class=builder_class, **kwargs) diff --git a/tests/test_graphBuilder.py b/tests/test_graphBuilder.py index b5587b8da..198516e93 100644 --- a/tests/test_graphBuilder.py +++ b/tests/test_graphBuilder.py @@ -32,18 +32,20 @@ import unittest import lsst.utils.tests -from lsst.daf.butler import Butler, DatasetType +from lsst.daf.butler import Butler, DataCoordinate, DatasetType from lsst.daf.butler.registry import UserExpressionError from lsst.pipe.base import PipelineGraph, QuantumGraph from lsst.pipe.base.all_dimensions_quantum_graph_builder import ( AllDimensionsQuantumGraphBuilder, DatasetQueryConstraintVariant, ) +from lsst.pipe.base.quantum_graph_builder import OutputExistsError from lsst.pipe.base.tests import simpleQGraph from lsst.pipe.base.tests.mocks import ( DynamicConnectionConfig, DynamicTestPipelineTask, DynamicTestPipelineTaskConfig, + InMemoryRepo, MockDataset, MockStorageClass, ) @@ -228,6 +230,146 @@ def test_datastore_records(self): self.assertEqual(quantum.datastore_records, {}) +class SkipExistingInTestCase(unittest.TestCase): + """Tests for the skip_existing_in behavior of QuantumGraphBuilder.""" + + def setUp(self): + self.helper = InMemoryRepo() + self.enterContext(self.helper) + self.helper.add_task() + self.helper.make_quantum_graph_builder(output_run="new_run") + self.helper.butler.collections.register("prior_run") + self._task_node = self.helper.pipeline_graph.tasks["task_auto1"] + self._empty_data_id = DataCoordinate.make_empty(self.helper.butler.dimensions) + + def _insert(self, *names, run="prior_run"): + """Register datasets with empty data IDs into a run collection.""" + for name in names: + dt = self.helper.pipeline_graph.dataset_types[name].dataset_type + self.helper.butler.registry.insertDatasets(dt, [self._empty_data_id], run=run) + + def _build(self, *, output_run="new_run", **kwargs): + return AllDimensionsQuantumGraphBuilder( + self.helper.pipeline_graph, + self.helper.butler, + input_collections=[self.helper.input_chain], + output_run=output_run, + **kwargs, + ).build(attach_datastore_records=False) + + def test_not_skipped_without_skip_existing_in(self): + """Without skip_existing_in, a quantum is never skipped even if + metadata exists in an input collection. + """ + self._insert(self._task_node.metadata_output.parent_dataset_type_name) + qgraph = self._build() + self.assertEqual(len(qgraph), 1) + + def test_skipped_when_metadata_exists(self): + """With skip_existing_in, a quantum is skipped when its metadata + dataset is present in the specified collections. + """ + self._insert(self._task_node.metadata_output.parent_dataset_type_name) + # Init-outputs required, otherwise InitInputMissingError. + for edge in self._task_node.init.iter_all_outputs(): + self._insert(edge.parent_dataset_type_name) + qgraph = self._build(skip_existing_in=["prior_run"]) + self.assertEqual(len(qgraph), 0) + + def test_not_skipped_when_metadata_absent(self): + """With skip_existing_in, a quantum is not skipped when its metadata + dataset is absent from the specified collections. + """ + qgraph = self._build(skip_existing_in=["prior_run"]) + self.assertEqual(len(qgraph), 1) + + +class RetryMissingOutputsForTestCase(unittest.TestCase): + """Tests for QuantumGraphBuilder.retry_missing_outputs_for.""" + + def setUp(self): + self.helper = InMemoryRepo() + self.enterContext(self.helper) + self.helper.add_task( + outputs={ + "out1": DynamicConnectionConfig(dataset_type_name="output_a"), + "out2": DynamicConnectionConfig(dataset_type_name="output_b"), + } + ) + self.helper.make_quantum_graph_builder(output_run="new_run") + self.helper.butler.collections.register("prior_run") + self._task_node = self.helper.pipeline_graph.tasks["task_auto1"] + self._empty_data_id = DataCoordinate.make_empty(self.helper.butler.dimensions) + # Prior run wrote metadata but did not retain output datasets. + self._insert(self._task_node.metadata_output.parent_dataset_type_name) + + def _insert(self, *names, run="prior_run"): + """Register datasets with empty data IDs into a run collection.""" + for name in names: + dt = self.helper.pipeline_graph.dataset_types[name].dataset_type + self.helper.butler.registry.insertDatasets(dt, [self._empty_data_id], run=run) + + def _build(self, *, output_run="new_run", **kwargs): + return AllDimensionsQuantumGraphBuilder( + self.helper.pipeline_graph, + self.helper.butler, + input_collections=[self.helper.input_chain], + output_run=output_run, + **kwargs, + ).build(attach_datastore_records=False) + + def test_not_skipped_when_outputs_missing(self): + """With retry_missing_outputs_for, quantum is not skipped when regular + outputs are absent from skip_existing_in, even if metadata is present. + + A scenario is that an upstream pipeline ran and wrote + metadata but did not retain output datasets. + """ + qgraph = self._build(skip_existing_in=["prior_run"], retry_missing_outputs_for=["task_auto1"]) + self.assertEqual(len(qgraph), 1) + + def test_skipped_when_all_outputs_present(self): + """With retry_missing_outputs_for, quantum is skipped when all regular + outputs are present in skip_existing_in. + """ + self._insert("output_a", "output_b") + # Init-outputs required when all quanta are skipped. + for edge in self._task_node.init.iter_all_outputs(): + self._insert(edge.parent_dataset_type_name) + qgraph = self._build(skip_existing_in=["prior_run"], retry_missing_outputs_for=["task_auto1"]) + self.assertEqual(len(qgraph), 0) + + def test_output_exists_error_when_partial_outputs(self): + """With retry_missing_outputs_for, OutputExistsError is raised when + some outputs exist in the output run and clobber is off. + """ + self._insert("output_a") + # output_b absent -> not all outputs present -> task not skipped + + with self.assertRaises(OutputExistsError): + self._build( + skip_existing_in=["prior_run"], + retry_missing_outputs_for=["task_auto1"], + # Use the same run so that partial output is in the way. + output_run="prior_run", + ) + + def test_partial_outputs_clobber(self): + """With retry_missing_outputs_for and clobber=True, partial outputs + in the output run are discarded and the task runs. + """ + self._insert("output_a") + # output_b absent -> not all outputs present -> task not skipped + # clobber=True -> output_a discarded from graph, task runs + qgraph = self._build( + skip_existing_in=["prior_run"], + retry_missing_outputs_for=["task_auto1"], + output_run="prior_run", + clobber=True, + ) + self.assertEqual(len(qgraph), 1) + + if __name__ == "__main__": lsst.utils.tests.init() unittest.main() diff --git a/tests/test_separable_pipeline_executor.py b/tests/test_separable_pipeline_executor.py index 5f71a18f6..629b9373f 100644 --- a/tests/test_separable_pipeline_executor.py +++ b/tests/test_separable_pipeline_executor.py @@ -588,6 +588,69 @@ def test_make_quantum_graph_nowhere_skippartial_noclobber(self): with self.assertRaises(OutputExistsError): executor.build_quantum_graph(pipeline) + def test_make_quantum_graph_nowhere_retrymissing_not_skipped(self): + """With retry_missing_outputs_for, a task is not skipped when its + regular output is absent, even if metadata is present. + """ + prior_run = "prior_run" + self.butler.registry.registerCollection(prior_run, lsst.daf.butler.CollectionType.RUN) + executor = SeparablePipelineExecutor( + self.butler, + skip_existing_in=[prior_run], + retry_missing_outputs_for=["a"], + clobber_output=False, + ) + pipeline = Pipeline.fromFile(self.pipeline_file) + + self.butler.put({"zero": 0}, "input") + # Metadata present in prior run but intermediate absent. + self.butler.put(TaskMetadata(), "a_metadata", run=prior_run) + + graph = executor.build_quantum_graph(pipeline) + self.assertEqual(len(graph), 2) + self.assertEqual(graph.quanta_by_task.keys(), {"a", "b"}) + + def test_make_quantum_graph_nowhere_retrymissing_skipped(self): + """With retry_missing_outputs_for, a task is skipped when its regular + output is present, even if metadata and log are absent. + """ + executor = SeparablePipelineExecutor( + self.butler, + skip_existing_in=[self.butler.run], + retry_missing_outputs_for=["a"], + clobber_output=False, + ) + pipeline = Pipeline.fromFile(self.pipeline_file) + + self.butler.put({"zero": 0}, "input") + # Regular output present; metadata and log intentionally absent. + self.butler.put({"zero": 0}, "intermediate") + self.butler.put(lsst.pex.config.Config(), "a_config") + + graph = executor.build_quantum_graph(pipeline) + self.assertEqual(len(graph), 1) + self.assertEqual(graph.header.n_task_quanta["a"], 0) + self.assertEqual(graph.header.n_task_quanta["b"], 1) + + def test_make_quantum_graph_nowhere_retrymissing_metainway_noclobber(self): + """With retry_missing_outputs_for, OutputExistsError is raised when + metadata is already in the output run and the task is not skipped. + """ + executor = SeparablePipelineExecutor( + self.butler, + skip_existing_in=[self.butler.run], + retry_missing_outputs_for=["a"], + clobber_output=False, + ) + pipeline = Pipeline.fromFile(self.pipeline_file) + + self.butler.put({"zero": 0}, "input") + self.butler.put(TaskMetadata(), "a_metadata") + # Metadata in output run, intermediate absent -> task not skipped + # Same output run, a_metadata is in the way. + with self.assertRaises(OutputExistsError): + executor.build_quantum_graph(pipeline) + def test_build_quantum_graph_nowhere_noskip_clobber(self): executor = SeparablePipelineExecutor(self.butler, skip_existing_in=None, clobber_output=True) pipeline = Pipeline.fromFile(self.pipeline_file)