Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/DM-54879.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added `retry_missing_outputs_for` parameter to `QuantumGraphBuilder` and `SeparablePipelineExecutor`.
100 changes: 68 additions & 32 deletions python/lsst/pipe/base/quantum_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ class QuantumGraphBuilder(ABC):
skip_existing_in : `~collections.abc.Sequence` [ `str` ], optional
Collections to search for outputs that already exist for the purpose of
skipping quanta that have already been run.
retry_missing_outputs_for : `~collections.abc.Iterable` [ `str` ], optional
Task labels for which the task metadata dataset is not used as the
completion signal when ``skip_existing_in`` is provided. For these
tasks a quantum is skipped only when all of its regular outputs
are present in ``skip_existing_in``. This is useful for pipelines
where upstream processing does not retain all task outputs, so that
those tasks can be re-run to regenerate the missing intermediate
datasets.
clobber : `bool`, optional
Whether to raise if predicted outputs already exist in ``output_run``
(not including those quanta that would be skipped because they've
Expand Down Expand Up @@ -171,6 +179,7 @@ def __init__(
input_collections: Sequence[str] | None = None,
output_run: str | None = None,
skip_existing_in: Sequence[str] = (),
retry_missing_outputs_for: Iterable[str] = (),
clobber: bool = False,
):
self.log = getLogger(__name__)
Expand All @@ -188,6 +197,7 @@ def __init__(
self.butler = butler.clone(collections=input_collections)
self.output_run = output_run
self.skip_existing_in = skip_existing_in
self.retry_missing_outputs_for: frozenset[str] = frozenset(retry_missing_outputs_for)
self.empty_data_id = DataCoordinate.make_empty(butler.dimensions)
self.clobber = clobber
# See whether the output run already exists.
Expand Down Expand Up @@ -531,7 +541,7 @@ def _resolve_task_quanta(self, task_node: TaskNode, skeleton: QuantumGraphSkelet
# gotten rid of.
skipped_quanta = []
for quantum_key in skeleton.get_quanta(task_node.label):
if self._skip_quantum_if_metadata_exists(task_node, quantum_key, skeleton):
if self._skip_quantum_if_done(task_node, quantum_key, skeleton):
skipped_quanta.append(quantum_key)
continue
quantum_data_id = skeleton[quantum_key]["data_id"]
Expand Down Expand Up @@ -699,11 +709,11 @@ def _get_task_inputs_if_overall_only(self, task_node: TaskNode) -> list[str] | N
return None
return result

def _skip_quantum_if_metadata_exists(
def _skip_quantum_if_done(
self, task_node: TaskNode, quantum_key: QuantumKey, skeleton: QuantumGraphSkeleton
) -> bool:
"""Identify and drop quanta that should be skipped because their
metadata datasets already exist.
metadata or output datasets already exist in ``skip_existing_in``.

Parameters
----------
Expand All @@ -722,41 +732,67 @@ def _skip_quantum_if_metadata_exists(

Notes
-----
If the metadata dataset for this quantum exists in the
`skip_existing_in` collections, the quantum will be skipped. This
causes the quantum node to be removed from the graph. Dataset nodes
For tasks not listed in `retry_missing_outputs_for`, a quantum is
skipped when its metadata dataset exists in ``skip_existing_in``.

For tasks listed in `retry_missing_outputs_for`, the metadata dataset
is not used as the completion signal. Instead, a quantum is skipped
only when all of its regular outputs are present in
``skip_existing_in``. If any such output is absent the quantum is
not skipped, so the task can regenerate the missing outputs. This
supports pipelines where upstream processing does not retain all task
outputs.

A skipped quantum's node is removed from the graph by the caller. Dataset nodes
that were previously the outputs of this quantum will be associated
with `lsst.daf.butler.DatasetRef` objects that were found in
``skip_existing_in``, or will be removed if there is no such dataset
there. Any output dataset in `output_run` will be removed from the
"output in the way" category.
"""
metadata_dataset_key = DatasetKey(
task_node.metadata_output.parent_dataset_type_name, quantum_key.data_id_values
)
if skeleton.get_output_for_skip(metadata_dataset_key):
# This quantum's metadata is already present in the
# skip_existing_in collections; we'll skip it. But the presence of
# the metadata dataset doesn't guarantee that all of the other
# outputs we predicted are present; we have to check.
for output_dataset_key in list(skeleton.iter_outputs_of(quantum_key)):
# If this dataset was "in the way" (i.e. already in the
# output run), it isn't anymore.
skeleton.discard_output_in_the_way(output_dataset_key)
if (output_ref := skeleton.get_output_for_skip(output_dataset_key)) is not None:
# Populate the skeleton graph's node attributes
# with the existing DatasetRef, just like a
# predicted output of a non-skipped quantum.
skeleton.set_dataset_ref(output_ref, output_dataset_key)
else:
# Remove this dataset from the skeleton graph,
# because the quantum that would have produced it
# is being skipped and it doesn't already exist.
skeleton.remove_dataset_nodes([output_dataset_key])
# Removing the quantum node from the graph will happen outside this
# function.
return True
return False
metadata_name = task_node.metadata_output.parent_dataset_type_name
metadata_dataset_key = DatasetKey(metadata_name, quantum_key.data_id_values)
log_name = task_node.log_output.parent_dataset_type_name if task_node.log_output is not None else None

if task_node.label in self.retry_missing_outputs_for:
# For this task, use actual output datasets as the completion
# signal rather than metadata. Skip only if all regular
# outputs are present in skip_existing_in; do not skip if
# any are absent so they can be regenerated.
regular_output_keys = [
k
for k in skeleton.iter_outputs_of(quantum_key)
if k.parent_dataset_type_name != metadata_name and k.parent_dataset_type_name != log_name
]
# No regular outputs is treated as not done.
if not regular_output_keys or any(
skeleton.get_output_for_skip(k) is None for k in regular_output_keys
):
return False
# All regular outputs are present; fall through to skip.
elif not skeleton.get_output_for_skip(metadata_dataset_key):
# metadata absent: do not skip.
return False

# We will skip the quantum, but that does not guarantee that every
# output we predicted is present; we have to check each one.
for output_dataset_key in list(skeleton.iter_outputs_of(quantum_key)):
# If this dataset was "in the way" (i.e. already in the
# output run), it isn't anymore.
skeleton.discard_output_in_the_way(output_dataset_key)
if (output_ref := skeleton.get_output_for_skip(output_dataset_key)) is not None:
# Populate the skeleton graph's node attributes
# with the existing DatasetRef, just like a
# predicted output of a non-skipped quantum.
skeleton.set_dataset_ref(output_ref, output_dataset_key)
else:
# Remove this dataset from the skeleton graph,
# because the quantum that would have produced it
# is being skipped and it doesn't already exist.
skeleton.remove_dataset_nodes([output_dataset_key])
# Removing the quantum node from the graph will happen outside this
# function.
return True

@final
def _update_quantum_for_adjust(
Expand Down
10 changes: 10 additions & 0 deletions python/lsst/pipe/base/separable_pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ class SeparablePipelineExecutor:
for existing outputs, and skips any quanta that have run to completion
(or have no work to do). Otherwise, all tasks are attempted (subject to
``clobber_output``).
retry_missing_outputs_for : `~collections.abc.Iterable` [`str`], optional
Task labels for which ``skip_existing_in`` uses the presence of all
regular outputs, rather than the task metadata dataset, as the
completion signal. Those tasks are re-run if any regular output is
absent. Has no effect without ``skip_existing_in``.
task_factory : `.TaskFactory`, optional
A custom task factory for use in pre-execution and execution. By
default, a new instance of `.TaskFactory` is used.
Expand All @@ -101,6 +106,7 @@ def __init__(
butler: Butler,
clobber_output: bool = False,
skip_existing_in: Iterable[str] | None = None,
retry_missing_outputs_for: Iterable[str] | None = None,
task_factory: TaskFactory | None = None,
resources: ExecutionResources | None = None,
raise_on_partial_outputs: bool = True,
Expand All @@ -115,6 +121,7 @@ def __init__(

self._clobber_output = clobber_output
self._skip_existing_in = list(skip_existing_in) if skip_existing_in else []
self._retry_missing_outputs_for = list(retry_missing_outputs_for) if retry_missing_outputs_for else []

self._task_factory = task_factory if task_factory else TaskFactory()
self.resources = resources
Expand Down Expand Up @@ -216,6 +223,7 @@ class are provided automatically (from explicit arguments to this
pipeline.to_graph(),
self._butler,
skip_existing_in=self._skip_existing_in,
retry_missing_outputs_for=self._retry_missing_outputs_for,
clobber=self._clobber_output,
**kwargs,
)
Expand Down Expand Up @@ -276,6 +284,7 @@ class are provided automatically (from explicit arguments to this
"output_run": self._butler.run,
"skip_existing_in": self._skip_existing_in,
"skip_existing": bool(self._skip_existing_in),
"retry_missing_outputs_for": self._retry_missing_outputs_for,
"data_query": where,
"user": getpass.getuser(),
"time": str(datetime.datetime.now()),
Expand Down Expand Up @@ -344,6 +353,7 @@ class are provided automatically (from explicit arguments to this
metadata = {
"skip_existing_in": self._skip_existing_in,
"skip_existing": bool(self._skip_existing_in),
"retry_missing_outputs_for": self._retry_missing_outputs_for,
"data_query": where,
}
qg_builder = self.make_quantum_graph_builder(pipeline, where, builder_class=builder_class, **kwargs)
Expand Down
144 changes: 143 additions & 1 deletion tests/test_graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,20 @@
import unittest

import lsst.utils.tests
from lsst.daf.butler import Butler, DatasetType
from lsst.daf.butler import Butler, DataCoordinate, DatasetType
from lsst.daf.butler.registry import UserExpressionError
from lsst.pipe.base import PipelineGraph, QuantumGraph
from lsst.pipe.base.all_dimensions_quantum_graph_builder import (
AllDimensionsQuantumGraphBuilder,
DatasetQueryConstraintVariant,
)
from lsst.pipe.base.quantum_graph_builder import OutputExistsError
from lsst.pipe.base.tests import simpleQGraph
from lsst.pipe.base.tests.mocks import (
DynamicConnectionConfig,
DynamicTestPipelineTask,
DynamicTestPipelineTaskConfig,
InMemoryRepo,
MockDataset,
MockStorageClass,
)
Expand Down Expand Up @@ -228,6 +230,146 @@ def test_datastore_records(self):
self.assertEqual(quantum.datastore_records, {})


class SkipExistingInTestCase(unittest.TestCase):
    """Check how ``skip_existing_in`` affects quantum graph building.

    Each test uses a single-task pipeline and an initially empty
    ``prior_run`` collection, and builds a graph that writes to ``new_run``.
    """

    def setUp(self):
        repo = InMemoryRepo()
        self.helper = repo
        self.enterContext(repo)
        repo.add_task()
        repo.make_quantum_graph_builder(output_run="new_run")
        repo.butler.collections.register("prior_run")
        self._task_node = repo.pipeline_graph.tasks["task_auto1"]
        self._empty_data_id = DataCoordinate.make_empty(repo.butler.dimensions)

    def _insert(self, *names, run="prior_run"):
        """Insert one empty-data-ID dataset per named dataset type.

        Parameters
        ----------
        *names : `str`
            Dataset type names, looked up in the helper's pipeline graph.
        run : `str`, optional
            Run collection that receives the datasets.
        """
        dataset_types = self.helper.pipeline_graph.dataset_types
        registry = self.helper.butler.registry
        for type_name in names:
            registry.insertDatasets(
                dataset_types[type_name].dataset_type,
                [self._empty_data_id],
                run=run,
            )

    def _build(self, *, output_run="new_run", **kwargs):
        """Build a quantum graph for the helper's pipeline.

        Parameters
        ----------
        output_run : `str`, optional
            Output run passed to the builder.
        **kwargs
            Extra keyword arguments forwarded to
            `AllDimensionsQuantumGraphBuilder`.

        Returns
        -------
        qgraph
            Result of ``build(attach_datastore_records=False)``.
        """
        builder = AllDimensionsQuantumGraphBuilder(
            self.helper.pipeline_graph,
            self.helper.butler,
            input_collections=[self.helper.input_chain],
            output_run=output_run,
            **kwargs,
        )
        return builder.build(attach_datastore_records=False)

    def test_not_skipped_without_skip_existing_in(self):
        """Metadata in ``prior_run`` alone must not cause a skip when
        ``skip_existing_in`` is not given.
        """
        metadata_name = self._task_node.metadata_output.parent_dataset_type_name
        self._insert(metadata_name)
        self.assertEqual(len(self._build()), 1)

    def test_skipped_when_metadata_exists(self):
        """The quantum is skipped when its metadata dataset is found in
        the ``skip_existing_in`` collections.
        """
        metadata_name = self._task_node.metadata_output.parent_dataset_type_name
        self._insert(metadata_name)
        # Init-outputs required, otherwise InitInputMissingError.
        init_output_names = [
            edge.parent_dataset_type_name for edge in self._task_node.init.iter_all_outputs()
        ]
        self._insert(*init_output_names)
        qgraph = self._build(skip_existing_in=["prior_run"])
        self.assertEqual(len(qgraph), 0)

    def test_not_skipped_when_metadata_absent(self):
        """The quantum is kept when the ``skip_existing_in`` collections
        do not contain its metadata dataset.
        """
        qgraph = self._build(skip_existing_in=["prior_run"])
        self.assertEqual(len(qgraph), 1)


class RetryMissingOutputsForTestCase(unittest.TestCase):
    """Exercise ``QuantumGraphBuilder.retry_missing_outputs_for``.

    The fixture is a single task with two regular outputs (``output_a``
    and ``output_b``). ``prior_run`` starts with only the task's metadata
    dataset and neither regular output.
    """

    def setUp(self):
        repo = InMemoryRepo()
        self.helper = repo
        self.enterContext(repo)
        output_connections = {
            "out1": DynamicConnectionConfig(dataset_type_name="output_a"),
            "out2": DynamicConnectionConfig(dataset_type_name="output_b"),
        }
        repo.add_task(outputs=output_connections)
        repo.make_quantum_graph_builder(output_run="new_run")
        repo.butler.collections.register("prior_run")
        self._task_node = repo.pipeline_graph.tasks["task_auto1"]
        self._empty_data_id = DataCoordinate.make_empty(repo.butler.dimensions)
        # Prior run wrote metadata but did not retain output datasets.
        self._insert(self._task_node.metadata_output.parent_dataset_type_name)

    def _insert(self, *names, run="prior_run"):
        """Insert one empty-data-ID dataset per named dataset type.

        Parameters
        ----------
        *names : `str`
            Dataset type names, looked up in the helper's pipeline graph.
        run : `str`, optional
            Run collection that receives the datasets.
        """
        dataset_types = self.helper.pipeline_graph.dataset_types
        registry = self.helper.butler.registry
        for type_name in names:
            registry.insertDatasets(
                dataset_types[type_name].dataset_type,
                [self._empty_data_id],
                run=run,
            )

    def _build(self, *, output_run="new_run", **kwargs):
        """Build a quantum graph for the helper's pipeline.

        Parameters
        ----------
        output_run : `str`, optional
            Output run passed to the builder.
        **kwargs
            Extra keyword arguments forwarded to
            `AllDimensionsQuantumGraphBuilder`.

        Returns
        -------
        qgraph
            Result of ``build(attach_datastore_records=False)``.
        """
        builder = AllDimensionsQuantumGraphBuilder(
            self.helper.pipeline_graph,
            self.helper.butler,
            input_collections=[self.helper.input_chain],
            output_run=output_run,
            **kwargs,
        )
        return builder.build(attach_datastore_records=False)

    def test_not_skipped_when_outputs_missing(self):
        """The quantum is kept when regular outputs are absent from
        ``skip_existing_in``, even though the metadata dataset is there.

        This models an upstream pipeline that ran and wrote metadata but
        did not retain its output datasets.
        """
        qgraph = self._build(
            skip_existing_in=["prior_run"],
            retry_missing_outputs_for=["task_auto1"],
        )
        self.assertEqual(len(qgraph), 1)

    def test_skipped_when_all_outputs_present(self):
        """The quantum is skipped once every regular output is present in
        ``skip_existing_in``.
        """
        self._insert("output_a", "output_b")
        # Init-outputs required when all quanta are skipped.
        init_output_names = [
            edge.parent_dataset_type_name for edge in self._task_node.init.iter_all_outputs()
        ]
        self._insert(*init_output_names)
        qgraph = self._build(
            skip_existing_in=["prior_run"],
            retry_missing_outputs_for=["task_auto1"],
        )
        self.assertEqual(len(qgraph), 0)

    def test_output_exists_error_when_partial_outputs(self):
        """`OutputExistsError` is raised when only some outputs exist in
        the output run and clobber is off.
        """
        self._insert("output_a")
        # output_b absent -> not all outputs present -> task not skipped
        self.assertRaises(
            OutputExistsError,
            self._build,
            skip_existing_in=["prior_run"],
            retry_missing_outputs_for=["task_auto1"],
            # Use the same run so that partial output is in the way.
            output_run="prior_run",
        )

    def test_partial_outputs_clobber(self):
        """With ``clobber=True``, partial outputs in the output run do not
        block the build and the task is still scheduled.
        """
        self._insert("output_a")
        # output_b absent -> not all outputs present -> task not skipped
        # clobber=True -> output_a discarded from graph, task runs
        build_options = {
            "skip_existing_in": ["prior_run"],
            "retry_missing_outputs_for": ["task_auto1"],
            "output_run": "prior_run",
            "clobber": True,
        }
        qgraph = self._build(**build_options)
        self.assertEqual(len(qgraph), 1)


if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()
Loading
Loading