diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 8018ef648..09954b9b6 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -5,13 +5,17 @@ import contextlib import functools +import json import logging import os import time import uuid +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Callable +from pydantic import ValidationError + import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import CustomColumnConfig from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType @@ -22,6 +26,7 @@ ProcessorConfig, ProcessorType, ) +from data_designer.config.utils.type_helpers import StrEnum from data_designer.config.utils.warning_helpers import warn_at_caller from data_designer.config.version import get_library_version from data_designer.engine.column_generators.generators.base import ( @@ -53,7 +58,7 @@ from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry from data_designer.engine.resources.resource_provider import ResourceProvider -from data_designer.engine.storage.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage +from data_designer.engine.storage.artifact_storage import SDG_CONFIG_FILENAME, ArtifactStorage, ResumeMode from data_designer.engine.storage.media_storage import StorageMode if TYPE_CHECKING: @@ -94,6 +99,21 @@ def _is_async_trace_enabled(settings: RunConfig) -> bool: return settings.async_trace or os.environ.get("DATA_DESIGNER_ASYNC_TRACE", "0") == "1" +class _ConfigCompatibility(StrEnum): + COMPATIBLE = "compatible" + 
INCOMPATIBLE = "incompatible" + NO_PRIOR_DATASET = "no_prior_dataset" + + +@dataclass +class _ResumeState: + num_completed_batches: int + actual_num_records: int + buffer_size: int + target_num_records: int + original_target_num_records: int + + class DatasetBuilder: def __init__( self, @@ -197,6 +217,7 @@ def build( num_records: int, on_batch_complete: Callable[[Path], None] | None = None, save_multimedia_to_disk: bool = True, + resume: ResumeMode = ResumeMode.NEVER, ) -> Path: """Build the dataset. @@ -206,13 +227,61 @@ def build( save_multimedia_to_disk: Whether to save generated multimedia (images, audio, video) to disk. If False, multimedia is stored directly in the DataFrame (e.g., images as base64). Default is True. + resume: Controls how interrupted runs are handled. + + - ``ResumeMode.NEVER`` (default): always start a fresh generation run. + - ``ResumeMode.ALWAYS``: resume from the last completed batch (sync) or row group + (async). ``buffer_size`` must match the original run. ``num_records`` may be + equal to or greater than what was already generated (you can extend the dataset); + ``num_records`` less than actual records so far raises ``DatasetGenerationError``. + If no checkpoint exists yet (interrupted before the first batch finished), silently + restarts from the beginning. Raises if the stored config is incompatible. + - ``ResumeMode.IF_POSSIBLE``: like ``ALWAYS`` when the current config fingerprint + matches the stored config; otherwise starts a fresh run without raising an error. + + In all resume modes, in-flight partial results from the interrupted run are + discarded before generation continues. Returns: Path to the generated dataset directory. """ self._reset_run_state() + self._run_model_health_check_if_needed() self._run_mcp_tool_check_if_needed() + + # For IF_POSSIBLE and ALWAYS: check config compatibility before touching the artifact + # directory. 
_check_resume_config_compatibility() must NOT access base_dataset_path + # (which would cache resolved_dataset_name prematurely). After the decision, sync + # artifact_storage.resume so that resolved_dataset_name picks up the right semantics + # on its first real access. + # + # Also invalidate any stale resolved_dataset_name cache: ArtifactStorage's Pydantic + # validator accesses base_dataset_path at construction time, which caches resolved_dataset_name + # under the original resume mode semantics. Popping it forces a fresh resolution. + if resume in (ResumeMode.IF_POSSIBLE, ResumeMode.ALWAYS): + compat = self._check_resume_config_compatibility() + if resume == ResumeMode.ALWAYS and compat == _ConfigCompatibility.INCOMPATIBLE: + raise DatasetGenerationError( + "πŸ›‘ Cannot resume: the current config does not match the config used in the interrupted run. " + "Use resume=ResumeMode.IF_POSSIBLE to start fresh automatically, or " + "resume=ResumeMode.NEVER to force a new run." + ) + if resume == ResumeMode.IF_POSSIBLE: + if compat != _ConfigCompatibility.COMPATIBLE: + if compat == _ConfigCompatibility.INCOMPATIBLE: + logger.info( + "▢️ Config has changed since the last run β€” starting a fresh generation (resume=IF_POSSIBLE)." 
+ ) + resume = ResumeMode.NEVER + self.artifact_storage.resume = ResumeMode.NEVER + self.artifact_storage.__dict__.pop("resolved_dataset_name", None) + self.artifact_storage.refresh_media_storage_path() + else: + resume = ResumeMode.ALWAYS + self.artifact_storage.resume = ResumeMode.ALWAYS + self.artifact_storage.__dict__.pop("resolved_dataset_name", None) + self._write_builder_config() # Set media storage mode based on parameters @@ -224,9 +293,24 @@ def build( start_time = time.perf_counter() buffer_size = self._resource_provider.run_config.buffer_size + if resume == ResumeMode.ALWAYS and not self.artifact_storage.metadata_file_path.exists(): + # No metadata.json means the previous run was interrupted before any batch (sync) or + # row group (async) completed. Nothing to resume β€” discard any leftover partial + # results and start fresh. + logger.info( + "▢️ No metadata.json found β€” the previous run was interrupted before any batch " + "completed. Starting generation from the beginning." 
+ ) + self.artifact_storage.clear_partial_results() + resume = ResumeMode.NEVER + self.artifact_storage.resume = ResumeMode.NEVER + + generated = True self._use_async = DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility() if self._use_async: - self._build_async(generators, num_records, buffer_size, on_batch_complete) + generated = self._build_async(generators, num_records, buffer_size, on_batch_complete, resume=resume) + elif resume == ResumeMode.ALWAYS: + generated = self._build_with_resume(generators, num_records, buffer_size, on_batch_complete) else: group_id = uuid.uuid4().hex self.batch_manager.start(num_records=num_records, buffer_size=buffer_size) @@ -241,11 +325,124 @@ def build( ) self.batch_manager.finish() - self._processor_runner.run_after_generation(buffer_size) + if generated: + self._processor_runner.run_after_generation(buffer_size) self._resource_provider.model_registry.log_model_usage(time.perf_counter() - start_time) return self.artifact_storage.final_dataset_path + def _load_resume_state(self, num_records: int, buffer_size: int) -> _ResumeState: + """Read and validate resume state from an existing metadata.json. + + ``num_records`` must be >= the number of records already generated (you may extend + the dataset, but cannot shrink it below what has been written). ``buffer_size`` must + exactly match the original run because it determines row-group boundaries. + + Raises: + DatasetGenerationError: If metadata is missing or incompatible with the current run parameters. + """ + try: + metadata = self.artifact_storage.read_metadata() + except FileNotFoundError as exc: + raise DatasetGenerationError( + "πŸ›‘ Cannot resume: metadata.json not found in the existing dataset directory. " + "Run without resume=ResumeMode.ALWAYS to start a new generation." 
+ ) from exc + + actual_num_records = metadata.get("actual_num_records", 0) + if num_records < actual_num_records: + raise DatasetGenerationError( + f"πŸ›‘ Cannot resume: num_records={num_records} is less than the {actual_num_records} " + "records already generated. Use num_records >= actual_num_records, " + "or start a new run without resume=ResumeMode.ALWAYS." + ) + + target_num_records = metadata.get("target_num_records") + if target_num_records is not None and num_records < target_num_records: + raise DatasetGenerationError( + f"πŸ›‘ Cannot resume: num_records={num_records} is less than the original target " + f"({target_num_records}). To resume, use num_records >= {target_num_records} " + "(you may extend the dataset beyond the original target). " + "Use resume=ResumeMode.NEVER to start a new run." + ) + + meta_buffer_size = metadata.get("buffer_size") + if meta_buffer_size != buffer_size: + raise DatasetGenerationError( + f"πŸ›‘ Cannot resume: buffer_size={buffer_size} does not match the original run's " + f"buffer_size={meta_buffer_size}. Use the same buffer_size as the interrupted run, " + "or start a new run without resume=ResumeMode.ALWAYS." + ) + + return _ResumeState( + num_completed_batches=metadata["num_completed_batches"], + actual_num_records=actual_num_records, + buffer_size=buffer_size, + target_num_records=metadata["target_num_records"], + original_target_num_records=metadata.get("original_target_num_records", metadata["target_num_records"]), + ) + + def _build_with_resume( + self, + generators: list[ColumnGenerator], + num_records: int, + buffer_size: int, + on_batch_complete: Callable[[Path], None] | None, + ) -> bool: + """Resume generation from the last completed batch. + + Returns: + False if the dataset was already complete (no new records generated), + True after successfully generating the remaining batches. + """ + state = self._load_resume_state(num_records, buffer_size) + + # Compute the correct per-batch sizes. 
ceil(num_records/bs) is wrong for a + # non-aligned extension: original groups are immutable, so any extension always + # adds new groups beyond num_original_batches. + original_target = state.original_target_num_records + num_original_batches = -(-original_target // buffer_size) + extension_records = num_records - original_target + num_extension_batches = -(-extension_records // buffer_size) + original_sizes = [min(buffer_size, original_target - i * buffer_size) for i in range(num_original_batches)] + extension_sizes = [min(buffer_size, extension_records - i * buffer_size) for i in range(num_extension_batches)] + + self.batch_manager.start( + num_records=num_records, + buffer_size=buffer_size, + start_batch=state.num_completed_batches, + initial_actual_num_records=state.actual_num_records, + num_records_list=original_sizes + extension_sizes, + original_target_num_records=original_target, + ) + + if state.num_completed_batches >= self.batch_manager.num_batches: + logger.warning( + "⚠️ Dataset is already complete β€” all batches were found in the existing artifact directory. " + "Nothing to resume. Use resume=ResumeMode.NEVER if you want to generate a new dataset." + ) + return False + + logger.info( + f"▢️ Resuming from batch {state.num_completed_batches + 1} of {self.batch_manager.num_batches} " + f"({state.actual_num_records} records already generated)." 
+ ) + + self.artifact_storage.clear_partial_results() + + group_id = uuid.uuid4().hex + for batch_idx in range(state.num_completed_batches, self.batch_manager.num_batches): + logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}") + self._run_batch( + generators, + batch_mode="batch", + group_id=group_id, + current_batch_number=batch_idx, + on_batch_complete=on_batch_complete, + ) + self.batch_manager.finish() + return True + def build_preview(self, *, num_records: int) -> pd.DataFrame: self._reset_run_state() self._run_model_health_check_if_needed() @@ -341,25 +538,141 @@ def _resolve_async_compatibility(self) -> bool: return False return True + def _find_completed_row_group_ids(self) -> set[int]: + """Scan the final dataset directory for already-written row group parquet files. + + Returns: + Set of row-group IDs (batch numbers) that have a parquet file in ``parquet-files/``. + """ + final_path = self.artifact_storage.final_dataset_path + if not final_path.exists(): + return set() + ids: set[int] = set() + for p in final_path.glob("batch_*.parquet"): + try: + ids.add(int(p.stem.split("_", 1)[1])) + except (ValueError, IndexError): + continue + return ids + + def _check_resume_config_compatibility(self) -> _ConfigCompatibility: + """Compare the current config fingerprint against the stored builder_config.json. + + Returns: + NO_PRIOR_DATASET β€” directory absent or empty (no prior run to resume from). + COMPATIBLE β€” fingerprints match, or stored config is unreadable (warning logged). + INCOMPATIBLE β€” fingerprints differ; continuing would mix records from two configs. + + Uses artifact_path / dataset_name directly β€” NOT base_dataset_path β€” to avoid + prematurely triggering the resolved_dataset_name cached_property before the + caller has had a chance to decide whether to resume or start fresh. 
+ """ + dataset_dir = Path(self.artifact_storage.artifact_path) / self.artifact_storage.dataset_name + if not dataset_dir.exists() or not any(dataset_dir.iterdir()): + return _ConfigCompatibility.NO_PRIOR_DATASET + config_path = dataset_dir / SDG_CONFIG_FILENAME + if not config_path.exists(): + logger.warning( + "⚠️ No builder_config.json found in %s β€” skipping config compatibility check on resume.", + dataset_dir, + ) + return _ConfigCompatibility.COMPATIBLE + try: + stored_data = json.loads(config_path.read_text()) + stored_config = BuilderConfig.model_validate(stored_data) + current_fp = self._data_designer_config.fingerprint()["config_hash"] + stored_fp = stored_config.data_designer.fingerprint()["config_hash"] + return _ConfigCompatibility.COMPATIBLE if current_fp == stored_fp else _ConfigCompatibility.INCOMPATIBLE + except (OSError, json.JSONDecodeError, ValidationError): + logger.warning( + "⚠️ Could not read stored config at %s for compatibility check β€” assuming compatible.", + config_path, + ) + return _ConfigCompatibility.COMPATIBLE + def _build_async( self, generators: list[ColumnGenerator], num_records: int, buffer_size: int, on_batch_complete: Callable[[Path], None] | None = None, - ) -> None: - """Async task-queue builder path - dispatches tasks based on dependency readiness.""" + *, + resume: ResumeMode = ResumeMode.NEVER, + ) -> bool: + """Async task-queue builder path - dispatches tasks based on dependency readiness. + + Returns: + False if the dataset was already complete (no new records generated), + True after successfully running the scheduler. 
+ """ logger.info("⚑ DATA_DESIGNER_ASYNC_ENGINE is enabled - using async task-queue builder") settings = self._resource_provider.run_config trace_enabled = _is_async_trace_enabled(settings) + precomputed_row_groups: list[tuple[int, int]] | None = None + initial_actual_num_records = 0 + initial_total_num_batches = 0 + original_target = num_records # immutable original target; overridden on resume + + if resume == ResumeMode.ALWAYS: + state = self._load_resume_state(num_records, buffer_size) + completed_ids = self._find_completed_row_group_ids() + # Use filesystem as source of truth for both counters β€” metadata may lag by one + # row group if a crash occurred between move_partial_result_to_final_file_path + # and write_metadata. + # Use the original target (not the new num_records) so the last row group of a + # non-aligned run gets its true size, not buffer_size. + initial_total_num_batches = len(completed_ids) + original_target = state.original_target_num_records + + num_original_groups = -(-original_target // buffer_size) # ceil(original_target/buffer_size) + + def _rg_size(rg_id: int) -> int: + if rg_id < num_original_groups: + return min(buffer_size, original_target - rg_id * buffer_size) + ext_group_idx = rg_id - num_original_groups + return min(buffer_size, (num_records - original_target) - ext_group_idx * buffer_size) + + initial_actual_num_records = sum(_rg_size(rg_id) for rg_id in completed_ids) + self.artifact_storage.clear_partial_results() + + # Original groups are immutable; any extension always needs new groups beyond + # num_original_groups β€” ceil(num_records/bs) gives the wrong count when the + # original run was non-aligned and the extension fits in the last group's slack. 
+ extension_records = num_records - original_target + total_row_groups = num_original_groups + -(-extension_records // buffer_size) + if len(completed_ids) >= total_row_groups: + logger.warning( + "⚠️ Dataset is already complete β€” all row groups were found in the existing artifact " + "directory. Nothing to resume. Use resume=ResumeMode.NEVER if you want to generate a new dataset." + ) + return False + + logger.info( + f"▢️ Resuming async run: {len(completed_ids)} of {total_row_groups} row group(s) already " + f"complete ({initial_actual_num_records} records), skipping them." + ) + + # Pre-compute the full row-group list with correct per-group sizes so that + # non-aligned skipped groups deduct their actual on-disk record count rather + # than buffer_size, keeping extension group sizes accurate. + precomputed_row_groups = [ + (rg_id, _rg_size(rg_id)) for rg_id in range(total_row_groups) if rg_id not in completed_ids + ] + def finalize_row_group(rg_id: int) -> None: def on_complete(final_path: Path | str | None) -> None: if final_path is not None and on_batch_complete: on_batch_complete(final_path) buffer_manager.checkpoint_row_group(rg_id, on_complete=on_complete) + # Write incremental metadata after each row group so interrupted runs can be resumed. 
+ buffer_manager.write_metadata( + target_num_records=num_records, + original_target_num_records=original_target, + buffer_size=buffer_size, + ) scheduler, buffer_manager = self._prepare_async_run( generators, @@ -370,6 +683,9 @@ def on_complete(final_path: Path | str | None) -> None: shutdown_error_window=settings.shutdown_error_window, disable_early_shutdown=settings.disable_early_shutdown, trace=trace_enabled, + precomputed_row_groups=precomputed_row_groups, + initial_actual_num_records=initial_actual_num_records, + initial_total_num_batches=initial_total_num_batches, ) # Telemetry snapshot @@ -398,8 +714,12 @@ def on_complete(final_path: Path | str | None) -> None: except Exception: logger.debug("Failed to emit batch telemetry for async run", exc_info=True) - # Write metadata - buffer_manager.write_metadata(target_num_records=num_records, buffer_size=buffer_size) + # Write final metadata (overwrites the last incremental write with identical content). + buffer_manager.write_metadata( + target_num_records=num_records, + original_target_num_records=original_target, + buffer_size=buffer_size, + ) # Surface partial completion actual = self._actual_num_records @@ -418,6 +738,8 @@ def on_complete(final_path: Path | str | None) -> None: else: logger.warning(base + "The dataset may be incomplete due to dropped rows.") + return True + def _prepare_async_run( self, generators: list[ColumnGenerator], @@ -430,6 +752,9 @@ def _prepare_async_run( shutdown_error_window: int = 10, disable_early_shutdown: bool = False, trace: bool = False, + precomputed_row_groups: list[tuple[int, int]] | None = None, + initial_actual_num_records: int = 0, + initial_total_num_batches: int = 0, ) -> tuple[AsyncTaskScheduler, RowGroupBufferManager]: """Build a fully-wired scheduler and buffer manager for async generation. 
@@ -452,18 +777,24 @@ def _prepare_async_run( for gen in generators: gen.log_pre_generation() - # Partition into row groups - row_groups: list[tuple[int, int]] = [] - remaining = num_records - rg_id = 0 - while remaining > 0: - size = min(buffer_size, remaining) - row_groups.append((rg_id, size)) - remaining -= size - rg_id += 1 + if precomputed_row_groups is not None: + row_groups = precomputed_row_groups + else: + row_groups = [] + remaining = num_records + rg_id = 0 + while remaining > 0: + size = min(buffer_size, remaining) + row_groups.append((rg_id, size)) + remaining -= size + rg_id += 1 tracker = CompletionTracker.with_graph(graph, row_groups) - buffer_manager = RowGroupBufferManager(self.artifact_storage) + buffer_manager = RowGroupBufferManager( + self.artifact_storage, + initial_actual_num_records=initial_actual_num_records, + initial_total_num_batches=initial_total_num_batches, + ) # Pre-batch processor callback: runs after seed tasks complete for a row group. # If it raises, the scheduler propagates the error as DatasetGenerationError (fail-fast). 
diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py index 757853300..f2bc39cdd 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dataset_batch_manager.py @@ -26,6 +26,7 @@ def __init__(self, artifact_storage: ArtifactStorage): self._num_records_list: list[int] | None = None self._buffer_size: int | None = None self._actual_num_records: int = 0 + self._original_target_num_records: int | None = None self.artifact_storage = artifact_storage @property @@ -87,9 +88,11 @@ def finish_batch(self, on_complete: Callable[[Path], None] | None = None) -> Pat self._actual_num_records += len(self._buffer) final_file_path = self.artifact_storage.move_partial_result_to_final_file_path(self._current_batch_number) + target = sum(self.num_records_list) self.artifact_storage.write_metadata( { - "target_num_records": sum(self.num_records_list), + "target_num_records": target, + "original_target_num_records": self._original_target_num_records or target, "actual_num_records": self._actual_num_records, "total_num_batches": self.num_batches, "buffer_size": self._buffer_size, @@ -158,17 +161,32 @@ def reset(self, delete_files: bool = False) -> None: except OSError as e: raise DatasetBatchManagementError(f"πŸ›‘ Failed to delete directory {dir_path}: {e}") - def start(self, *, num_records: int, buffer_size: int) -> None: + def start( + self, + *, + num_records: int, + buffer_size: int, + start_batch: int = 0, + initial_actual_num_records: int = 0, + num_records_list: list[int] | None = None, + original_target_num_records: int | None = None, + ) -> None: if num_records <= 0: raise DatasetBatchManagementError("πŸ›‘ num_records must be positive.") if buffer_size <= 0: raise 
DatasetBatchManagementError("πŸ›‘ buffer_size must be positive.") self._buffer_size = buffer_size - self._num_records_list = [buffer_size] * (num_records // buffer_size) - if remaining_records := num_records % buffer_size: - self._num_records_list.append(remaining_records) + self._original_target_num_records = original_target_num_records + if num_records_list is not None: + self._num_records_list = list(num_records_list) + else: + self._num_records_list = [buffer_size] * (num_records // buffer_size) + if remaining_records := num_records % buffer_size: + self._num_records_list.append(remaining_records) self.reset() + self._current_batch_number = start_batch + self._actual_num_records = initial_actual_num_records def write(self) -> Path | None: """Write the current batch to a parquet file. diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py index 5ddbfec8c..b4b28e9aa 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py @@ -28,13 +28,18 @@ class RowGroupBufferManager: exclusively by the async scheduler. 
""" - def __init__(self, artifact_storage: ArtifactStorage) -> None: + def __init__( + self, + artifact_storage: ArtifactStorage, + initial_actual_num_records: int = 0, + initial_total_num_batches: int = 0, + ) -> None: self._buffers: dict[int, list[dict]] = {} self._row_group_sizes: dict[int, int] = {} self._dropped: dict[int, set[int]] = {} self._artifact_storage = artifact_storage - self._actual_num_records: int = 0 - self._total_num_batches: int = 0 + self._actual_num_records: int = initial_actual_num_records + self._total_num_batches: int = initial_total_num_batches def init_row_group(self, row_group: int, size: int) -> None: """Allocate a buffer for *row_group* with *size* empty rows.""" @@ -129,11 +134,14 @@ def checkpoint_row_group( self.free_row_group(row_group) - def write_metadata(self, target_num_records: int, buffer_size: int) -> None: + def write_metadata( + self, target_num_records: int, buffer_size: int, original_target_num_records: int | None = None + ) -> None: """Write final metadata after all row groups are checkpointed.""" self._artifact_storage.write_metadata( { "target_num_records": target_num_records, + "original_target_num_records": original_target_num_records or target_num_records, "actual_num_records": self._actual_num_records, "total_num_batches": self._total_num_batches, "buffer_size": buffer_size, diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/artifact_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/artifact_storage.py index 458eed689..cd106a991 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/artifact_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/artifact_storage.py @@ -38,6 +38,12 @@ class BatchStage(StrEnum): PROCESSORS_OUTPUTS = "processors_outputs_path" +class ResumeMode(StrEnum): + NEVER = "never" + ALWAYS = "always" + IF_POSSIBLE = "if_possible" + + class ArtifactStorage(BaseModel): model_config = 
ConfigDict(arbitrary_types_allowed=True) @@ -47,6 +53,7 @@ class ArtifactStorage(BaseModel): partial_results_folder_name: str = "tmp-partial-parquet-files" dropped_columns_folder_name: str = "dropped-columns-parquet-files" processors_outputs_folder_name: str = PROCESSORS_OUTPUTS_FOLDER_NAME + resume: ResumeMode = ResumeMode.NEVER _media_storage: MediaStorage = PrivateAttr(default=None) @property @@ -67,12 +74,19 @@ def artifact_path_exists(self) -> bool: def resolved_dataset_name(self) -> str: dataset_path = self.artifact_path / self.dataset_name if dataset_path.exists() and len(list(dataset_path.iterdir())) > 0: + if self.resume in (ResumeMode.ALWAYS, ResumeMode.IF_POSSIBLE): + return self.dataset_name new_dataset_name = f"{self.dataset_name}_{datetime.now().strftime('%m-%d-%Y_%H%M%S')}" logger.info( f"πŸ“‚ Dataset path {str(dataset_path)!r} already exists. Dataset from this session" f"\n\t\t will be saved to {str(self.artifact_path / new_dataset_name)!r} instead." ) return new_dataset_name + if self.resume == ResumeMode.ALWAYS: + raise ArtifactStorageError( + f"πŸ›‘ Cannot resume: no existing dataset found at {str(dataset_path)!r}. " + "Run without resume=ResumeMode.ALWAYS to start a new generation." + ) return self.dataset_name @property @@ -144,6 +158,16 @@ def set_media_storage_mode(self, mode: StorageMode) -> None: """ self._media_storage.mode = mode + def refresh_media_storage_path(self) -> None: + """Re-point MediaStorage to the current base_dataset_path. + + Must be called after popping the resolved_dataset_name cache so that + _media_storage.base_path and .images_dir reflect the updated directory. 
+ """ + images_subdir = self._media_storage.images_dir.name + self._media_storage.base_path = self.base_dataset_path + self._media_storage.images_dir = self.base_dataset_path / images_subdir + @staticmethod def mkdir_if_needed(path: Path | str) -> Path: """Create the directory if it does not exist.""" @@ -204,6 +228,11 @@ def load_dataset_with_dropped_columns(self) -> pd.DataFrame: df = lazy.pd.concat([df, df_dropped], axis=1) return df + def clear_partial_results(self) -> None: + """Remove any in-flight partial results left over from an interrupted run.""" + if self.partial_results_path.exists(): + shutil.rmtree(self.partial_results_path) + def move_partial_result_to_final_file_path(self, batch_number: int) -> Path: partial_result_path = self.create_batch_file_path(batch_number, batch_stage=BatchStage.PARTIAL_RESULT) if not partial_result_path.exists(): diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py index 8c796374a..efb7d87f6 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py @@ -21,7 +21,7 @@ from data_designer.config.seed_source import LocalFileSeedSource from data_designer.config.seed_source_dataframe import DataFrameSeedSource from data_designer.engine.column_generators.generators.base import GenerationStrategy -from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder +from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder, _ConfigCompatibility from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.models.errors import ( FormattedLLMErrorMessage, @@ -33,6 +33,7 @@ from data_designer.engine.processing.processors.base import Processor from 
data_designer.engine.registry.data_designer_registry import DataDesignerRegistry from data_designer.engine.resources.seed_reader import DataFrameSeedReader +from data_designer.engine.storage.artifact_storage import ResumeMode if TYPE_CHECKING: import pandas as pd @@ -1433,3 +1434,789 @@ def test_skip_row_count_preserved_across_pipeline(stub_resource_provider, stub_m assert len(result) == 5, "Skip must not change the row count" assert result["seed_id"].tolist() == [1, 2, 3, 4, 5] + + +# --------------------------------------------------------------------------- +# Resume mechanism tests +# --------------------------------------------------------------------------- + + +import json as _json +from pathlib import Path as _Path + +from data_designer.engine.storage.artifact_storage import ArtifactStorage as _ArtifactStorage + + +def _write_metadata(dataset_dir: _Path, **fields) -> None: + """Write a metadata.json into an existing dataset folder.""" + dataset_dir.mkdir(parents=True, exist_ok=True) + (dataset_dir / "sentinel.txt").write_text("x") # make folder non-empty for resolved_dataset_name + (dataset_dir / "metadata.json").write_text(_json.dumps(fields)) + + +def _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, *, buffer_size: int = 2): + """Return a DatasetBuilder whose ArtifactStorage has resume=ResumeMode.ALWAYS.""" + storage = _ArtifactStorage(artifact_path=tmp_path, resume=ResumeMode.ALWAYS) + stub_resource_provider.artifact_storage = storage + stub_resource_provider.run_config = RunConfig(buffer_size=buffer_size) + return DatasetBuilder( + data_designer_config=stub_test_config_builder.build(), + resource_provider=stub_resource_provider, + ) + + +def test_build_resume_starts_fresh_without_metadata(stub_resource_provider, stub_test_config_builder, tmp_path, caplog): + """resume=True when only the folder exists (no metadata.json) logs an info message and starts fresh. 
+ + This covers the case where a run was interrupted before any batch completed β€” the + folder was created by _write_builder_config but metadata.json was never written. + Previously this raised DatasetGenerationError; now it silently restarts from batch 0. + """ + # Pre-create the folder with content so resolved_dataset_name(resume=True) returns "dataset" + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir() + (dataset_dir / "builder_config.json").write_text("{}") # non-empty, no metadata + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path) + with caplog.at_level(logging.INFO): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_run_batch"): + with patch.object(builder.batch_manager, "finish"): + # resume=False is set internally; build dispatches to the normal (non-resume) path + builder.build(num_records=4, resume=ResumeMode.ALWAYS) + + assert any("interrupted before any batch completed" in record.message for record in caplog.records) + + +def test_build_resume_raises_when_num_records_below_actual(stub_resource_provider, stub_test_config_builder, tmp_path): + """resume=ALWAYS raises when num_records is less than what has already been generated.""" + dataset_dir = tmp_path / "dataset" + _write_metadata( + dataset_dir, + target_num_records=10, + buffer_size=2, + num_completed_batches=3, + actual_num_records=6, + ) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + with pytest.raises(DatasetGenerationError, match="num_records=4 is less than the 6 records already generated"): + builder.build(num_records=4, resume=ResumeMode.ALWAYS) + + +def test_build_resume_raises_when_num_records_below_original_target( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """resume=ALWAYS raises when num_records is between actual and original target (negative extension_records).""" + dataset_dir = tmp_path / 
"dataset" + _write_metadata( + dataset_dir, + target_num_records=10, + buffer_size=2, + num_completed_batches=2, + actual_num_records=4, + ) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + with pytest.raises(DatasetGenerationError, match="num_records=7 is less than the original target"): + builder.build(num_records=7, resume=ResumeMode.ALWAYS) + + +def test_build_resume_allows_larger_num_records(stub_resource_provider, stub_test_config_builder, tmp_path, caplog): + """resume=ALWAYS succeeds when num_records > original target (extending the dataset).""" + dataset_dir = tmp_path / "dataset" + _write_metadata( + dataset_dir, + target_num_records=4, + buffer_size=2, + num_completed_batches=2, + actual_num_records=4, + ) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + with caplog.at_level(logging.WARNING): + with patch.object(builder, "_run_model_health_check_if_needed"): + # 6 > 4 already generated β†’ not already complete, should start generating + # Here we just verify it does NOT raise on the num_records check + with patch.object(builder, "_build_with_resume", return_value=True): + builder.build(num_records=6, resume=ResumeMode.ALWAYS) + + +def test_build_resume_raises_on_buffer_size_mismatch(stub_resource_provider, stub_test_config_builder, tmp_path): + """resume=True raises when buffer_size differs from the original run.""" + dataset_dir = tmp_path / "dataset" + _write_metadata( + dataset_dir, + target_num_records=4, + buffer_size=2, + num_completed_batches=1, + actual_num_records=2, + ) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=3) + with pytest.raises(DatasetGenerationError, match="buffer_size=3 does not match"): + builder.build(num_records=4, resume=ResumeMode.ALWAYS) + + +def test_build_resume_always_raises_on_config_mismatch(stub_resource_provider, 
stub_test_config_builder, tmp_path):
+    """resume=ALWAYS raises DatasetGenerationError when the stored config fingerprint differs."""
+    dataset_dir = tmp_path / "dataset"
+    _write_metadata(dataset_dir, target_num_records=4, buffer_size=2, num_completed_batches=1, actual_num_records=2)
+
+    builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path)
+    with patch.object(builder, "_check_resume_config_compatibility", return_value=_ConfigCompatibility.INCOMPATIBLE):
+        with pytest.raises(DatasetGenerationError, match="does not match the config used"):
+            builder.build(num_records=4, resume=ResumeMode.ALWAYS)
+
+
+def test_build_resume_logs_warning_when_already_complete(
+    stub_resource_provider, stub_test_config_builder, tmp_path, caplog
+):
+    """resume=ALWAYS on a fully-complete dataset logs a warning and returns without generating."""
+    dataset_dir = tmp_path / "dataset"
+    # 4 records, 2 per batch = 2 batches; num_completed_batches == 2 → already done
+    _write_metadata(
+        dataset_dir,
+        target_num_records=4,
+        buffer_size=2,
+        num_completed_batches=2,
+        actual_num_records=4,
+    )
+
+    builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2)
+    with caplog.at_level(logging.WARNING):
+        builder.build(num_records=4, resume=ResumeMode.ALWAYS)
+
+    assert any("already complete" in record.message for record in caplog.records)
+
+
+def test_build_resume_already_complete_does_not_run_after_generation_processors(
+    stub_resource_provider, stub_test_config_builder, tmp_path
+):
+    """When already complete, run_after_generation must NOT be called (would destroy the dataset)."""
+    dataset_dir = tmp_path / "dataset"
+    _write_metadata(
+        dataset_dir,
+        target_num_records=4,
+        buffer_size=2,
+        num_completed_batches=2,
+        actual_num_records=4,
+    )
+
+    builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2)
+    with patch.object(builder._processor_runner, 
"run_after_generation") as mock_after: + builder.build(num_records=4, resume=ResumeMode.ALWAYS) + + mock_after.assert_not_called() + + +def test_build_resume_not_already_complete_when_extension_fits_in_slack( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """Non-aligned extension fitting in the last group's slack must not falsely trigger 'already complete'. + + original_target=5, buffer_size=2 β†’ 3 batches [2,2,1]; extending to num_records=6: + ceil(6/2)=3 == num_completed_batches=3 used to trigger the false 'already complete' branch. + Correct total_batches = 3 + ceil(1/2) = 4, so batch 3 (1 record) must be scheduled. + """ + dataset_dir = tmp_path / "dataset" + _write_metadata(dataset_dir, target_num_records=5, buffer_size=2, num_completed_batches=3, actual_num_records=5) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + + with patch.object(builder, "_run_batch") as mock_run_batch: + with patch.object(builder.batch_manager, "finish"): + with patch.object(builder, "_run_model_health_check_if_needed"): + builder.build(num_records=6, resume=ResumeMode.ALWAYS) + + mock_run_batch.assert_called_once() + assert mock_run_batch.call_args.kwargs["current_batch_number"] == 3 + + +# --------------------------------------------------------------------------- +# Async resume via _build_async tests +# --------------------------------------------------------------------------- + + +def _write_parquet_files(parquet_dir: _Path, row_group_ids: list[int]) -> None: + """Create stub batch_*.parquet files for the given row group IDs.""" + parquet_dir.mkdir(parents=True, exist_ok=True) + for rg_id in row_group_ids: + (parquet_dir / f"batch_{rg_id:05d}.parquet").write_text("") + + +def test_build_async_resume_logs_warning_when_already_complete( + stub_resource_provider, stub_test_config_builder, tmp_path, caplog +): + """Async resume on a fully-complete dataset logs a warning and returns without 
running.""" + dataset_dir = tmp_path / "dataset" + # 4 records at buffer_size=2 β†’ 2 row groups (IDs 0 and 1) + _write_metadata(dataset_dir, target_num_records=4, buffer_size=2, num_completed_batches=2, actual_num_records=4) + _write_parquet_files(dataset_dir / "parquet-files", [0, 1]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + + with caplog.at_level(logging.WARNING): + with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder, "_run_model_health_check_if_needed"): + builder.build(num_records=4, resume=ResumeMode.ALWAYS) + + assert any("already complete" in record.message for record in caplog.records) + + +def test_build_async_resume_starts_fresh_without_metadata( + stub_resource_provider, stub_test_config_builder, tmp_path, caplog +): + """Async resume with no metadata.json logs an info message and starts fresh. + + Previously this raised DatasetGenerationError; now it silently restarts from row group 0. + The log is emitted in build() before dispatching to _build_async, so mocking _build_async + does not suppress the message. 
+ """ + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir() + (dataset_dir / "builder_config.json").write_text("{}") + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path) + + with caplog.at_level(logging.INFO): + with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_build_async", return_value=True) as mock_async: + builder.build(num_records=4, resume=ResumeMode.ALWAYS) + + # _build_async is called with resume=NEVER because the no-metadata path resets the mode + _, kwargs = mock_async.call_args + assert kwargs.get("resume") == ResumeMode.NEVER + assert any("interrupted before any batch completed" in record.message for record in caplog.records) + + +def test_build_async_resume_already_complete_does_not_run_after_generation_processors( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """Async resume: when already complete, run_after_generation must NOT be called.""" + dataset_dir = tmp_path / "dataset" + _write_metadata(dataset_dir, target_num_records=4, buffer_size=2, num_completed_batches=2, actual_num_records=4) + _write_parquet_files(dataset_dir / "parquet-files", [0, 1]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + + with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder._processor_runner, "run_after_generation") as mock_after: + builder.build(num_records=4, resume=ResumeMode.ALWAYS) + + mock_after.assert_not_called() + + +def test_find_completed_row_group_ids_used_for_initial_total_batches( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """initial_total_num_batches uses filesystem count, not metadata count. 
+ + Simulates the crash window: 2 parquet files exist on disk but metadata still + records num_completed_batches=1 (write_metadata crashed after the second + row group was moved to parquet-files/ but before metadata was updated). + Verifies that _find_completed_row_group_ids() (= 2) is used, not metadata (= 1). + """ + dataset_dir = tmp_path / "dataset" + # Metadata lags β€” says only 1 batch completed + _write_metadata(dataset_dir, target_num_records=4, buffer_size=2, num_completed_batches=1, actual_num_records=2) + # Filesystem truth β€” 2 row groups already written + _write_parquet_files(dataset_dir / "parquet-files", [0, 1]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + # Both row groups are on disk β†’ dataset is already complete β†’ generated=False + with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder._processor_runner, "run_after_generation") as mock_after: + builder.build(num_records=4, resume=ResumeMode.ALWAYS) + + # Already complete based on filesystem count (2 files β‰₯ 2 row groups) β€” no generation needed + mock_after.assert_not_called() + + +def test_initial_actual_num_records_from_filesystem_in_crash_window( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """initial_actual_num_records is derived from filesystem, not stale metadata. + + Crash window scenario: row groups 0 and 1 are on disk but metadata only records + num_completed_batches=1 / actual_num_records=2 (write_metadata crashed after + the second row group was written but before it updated the file). + + With 6 records and buffer_size=2 (3 row groups total), the correct + initial_actual_num_records is 4 (groups 0+1), not 2 (stale metadata value). 
+ """ + import asyncio as stdlib_asyncio + + dataset_dir = tmp_path / "dataset" + # Metadata lags β€” says only 1 batch completed with 2 records + _write_metadata(dataset_dir, target_num_records=6, buffer_size=2, num_completed_batches=1, actual_num_records=2) + # Filesystem truth β€” 2 row groups already written (ids 0 and 1) + _write_parquet_files(dataset_dir / "parquet-files", [0, 1]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + + captured: dict = {} + + def capturing_prepare(*args, **kwargs): + captured["initial_actual_num_records"] = kwargs.get("initial_actual_num_records", 0) + captured["initial_total_num_batches"] = kwargs.get("initial_total_num_batches", 0) + mock_scheduler = Mock() + mock_scheduler.traces = [] + mock_buffer_manager = Mock() + mock_buffer_manager.actual_num_records = 6 + return mock_scheduler, mock_buffer_manager + + mock_future = Mock() + mock_future.result = Mock(return_value=None) + + # asyncio and ensure_async_engine_loop are lazy-imported in dataset_builder only when + # DATA_DESIGNER_ASYNC_ENGINE=True at module load time. Inject them for the duration + # of this test so _build_async can proceed past the early-return path. 
+ with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): + with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): + with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): + builder.build(num_records=6, resume=ResumeMode.ALWAYS) + + # Filesystem says 2 groups done (IDs 0+1) β†’ 2+2 = 4 records, not stale metadata value 2 + assert captured["initial_actual_num_records"] == 4 + assert captured["initial_total_num_batches"] == 2 + + +def test_build_async_resume_initial_actual_num_records_uses_original_target( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """initial_actual_num_records uses the original target_num_records, not the new num_records. + + When extending a non-aligned run (original num_records=5, buffer_size=2 β†’ row groups [2,2,1]), + all 3 row groups completed. Resuming with num_records=7 must not use the new target in the + formula: min(2, 7-2*2)=min(2,3)=2 would give 6, but the actual data is 5 records. 
+ """ + import asyncio as stdlib_asyncio + + dataset_dir = tmp_path / "dataset" + # Original run: 5 records, buffer_size=2, all 3 row groups done + _write_metadata(dataset_dir, target_num_records=5, buffer_size=2, num_completed_batches=3, actual_num_records=5) + _write_parquet_files(dataset_dir / "parquet-files", [0, 1, 2]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + + captured: dict = {} + + def capturing_prepare(*args, **kwargs): + captured["initial_actual_num_records"] = kwargs.get("initial_actual_num_records", 0) + mock_scheduler = Mock() + mock_scheduler.traces = [] + mock_buffer_manager = Mock() + mock_buffer_manager.actual_num_records = 7 + return mock_scheduler, mock_buffer_manager + + mock_future = Mock() + mock_future.result = Mock(return_value=None) + + with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): + with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): + with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): + # Extend the dataset: new target is 7, original was 5 + builder.build(num_records=7, resume=ResumeMode.ALWAYS) + + # Row groups [2, 2, 1] from original 5-record run: 2+2+1=5, not 2+2+2=6 + assert captured["initial_actual_num_records"] == 5 + + +def test_build_async_resume_initial_actual_num_records_extension_crash_window( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """Extension row groups on disk use new num_records in the size formula, not original target. + + Crash window: original run had num_records=5, buffer_size=2 (row groups [2,2,1], all done). 
+ Extension starts with num_records=9; row group 3 (2 records) is written to disk but + write_metadata crashes before updating the file. On resume, completed_ids={0,1,2,3} + while metadata still reports target_num_records=5. + + Correct count: groups 0,1 β†’ 2+2; group 2 (last original, non-aligned) β†’ 1; group 3 + (extension) β†’ min(2, 9-6)=2. Total = 7, not 4 (which the unguarded formula gives, + since min(2, 5-6) = -1). + """ + import asyncio as stdlib_asyncio + + dataset_dir = tmp_path / "dataset" + _write_metadata(dataset_dir, target_num_records=5, buffer_size=2, num_completed_batches=3, actual_num_records=5) + _write_parquet_files(dataset_dir / "parquet-files", [0, 1, 2, 3]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + + captured: dict = {} + + def capturing_prepare(*args, **kwargs): + captured["initial_actual_num_records"] = kwargs.get("initial_actual_num_records", 0) + mock_scheduler = Mock() + mock_scheduler.traces = [] + mock_buffer_manager = Mock() + mock_buffer_manager.actual_num_records = 9 + return mock_scheduler, mock_buffer_manager + + mock_future = Mock() + mock_future.result = Mock(return_value=None) + + with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): + with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): + with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): + builder.build(num_records=9, resume=ResumeMode.ALWAYS) + + # 2+2+1 (original) + 2 (extension group 3) = 7, not 4 (which unguarded formula gives) + assert captured["initial_actual_num_records"] == 7 + + +def test_build_async_resume_stale_original_target_after_incremental_metadata_write( + 
stub_resource_provider, stub_test_config_builder, tmp_path +): + """original_target_num_records stays immutable even after an incremental metadata write. + + Scenario: original run had num_records=5, buffer_size=2 (row groups [2,2,1], all done). + Extension to num_records=9 starts; row group 3 (2 records) completes and finalize_row_group + writes metadata with target_num_records=9. Crash before row group 4. + + On second resume, metadata now shows target_num_records=9. Without the fix, original_target + would be read as 9, making num_original_groups=5 and producing wrong _rg_size values. + With the fix, original_target_num_records=5 is preserved in metadata, giving the correct + initial_actual_num_records=7 (2+2+1 original + 2 extension). + """ + import asyncio as stdlib_asyncio + + dataset_dir = tmp_path / "dataset" + # Metadata reflects a post-incremental-write state: target updated to 9, original still 5 + _write_metadata( + dataset_dir, + target_num_records=9, + original_target_num_records=5, + buffer_size=2, + num_completed_batches=4, + actual_num_records=7, + ) + _write_parquet_files(dataset_dir / "parquet-files", [0, 1, 2, 3]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + + captured: dict = {} + + def capturing_prepare(*args, **kwargs): + captured["initial_actual_num_records"] = kwargs.get("initial_actual_num_records", 0) + mock_scheduler = Mock() + mock_scheduler.traces = [] + mock_buffer_manager = Mock() + mock_buffer_manager.actual_num_records = 9 + return mock_scheduler, mock_buffer_manager + + mock_future = Mock() + mock_future.result = Mock(return_value=None) + + with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): + with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): + with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", 
return_value=mock_future): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): + builder.build(num_records=9, resume=ResumeMode.ALWAYS) + + # original_target=5 β†’ groups 0,1 β†’ 2+2; group 2 β†’ 1; group 3 (ext) β†’ min(2,9-6)=2. Total=7 + assert captured["initial_actual_num_records"] == 7 + + +def test_build_async_resume_skip_row_groups_contains_completed_ids( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """precomputed_row_groups passed to _prepare_async_run excludes already-completed row groups. + + Verifies the skip mechanism so the scheduler never re-generates a row group that + already has a parquet file on disk. 6 records, buffer_size=2 β†’ 3 row groups total; + row groups 0 and 2 already on disk β†’ only row group 1 should be scheduled. + """ + import asyncio as stdlib_asyncio + + dataset_dir = tmp_path / "dataset" + # 6 records, buffer_size=2 β†’ 3 row groups total; row groups 0 and 2 already on disk + _write_metadata(dataset_dir, target_num_records=6, buffer_size=2, num_completed_batches=2, actual_num_records=4) + _write_parquet_files(dataset_dir / "parquet-files", [0, 2]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + + captured: dict = {} + + def capturing_prepare(*args, **kwargs): + captured["precomputed_row_groups"] = kwargs.get("precomputed_row_groups") + mock_scheduler = Mock() + mock_scheduler.traces = [] + mock_buffer_manager = Mock() + mock_buffer_manager.actual_num_records = 6 + return mock_scheduler, mock_buffer_manager + + mock_future = Mock() + mock_future.result = Mock(return_value=None) + + with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): + with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): + with 
patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): + builder.build(num_records=6, resume=ResumeMode.ALWAYS) + + # Only rg_id=1 remains; rg_id=0 and rg_id=2 are already on disk + assert captured["precomputed_row_groups"] == [(1, 2)] + + +def test_build_async_resume_extension_non_aligned_row_group_sizes( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """Extension row groups get the correct size when the original run was non-aligned. + + Original run: num_records=5, buffer_size=2 β†’ row groups [2, 2, 1], all completed. + Extending to num_records=7: the loop previously deducted 2 for rg_id=2 (instead of 1), + leaving remaining=1 so rg_id=3 received size 1 instead of 2. 7 records were never + generated; only 6 reached the dataset and a false partial-completion warning fired. + + After the fix, precomputed_row_groups must be [(3, 2)], not [(3, 1)]. 
+ """ + import asyncio as stdlib_asyncio + + dataset_dir = tmp_path / "dataset" + _write_metadata(dataset_dir, target_num_records=5, buffer_size=2, num_completed_batches=3, actual_num_records=5) + _write_parquet_files(dataset_dir / "parquet-files", [0, 1, 2]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + + captured: dict = {} + + def capturing_prepare(*args, **kwargs): + captured["precomputed_row_groups"] = kwargs.get("precomputed_row_groups") + mock_scheduler = Mock() + mock_scheduler.traces = [] + mock_buffer_manager = Mock() + mock_buffer_manager.actual_num_records = 7 + return mock_scheduler, mock_buffer_manager + + mock_future = Mock() + mock_future.result = Mock(return_value=None) + + with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): + with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): + with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare): + builder.build(num_records=7, resume=ResumeMode.ALWAYS) + + # rg_id=3 should have 2 records (7-5=2 extension records, buffer_size=2), not 1 + assert captured["precomputed_row_groups"] == [(3, 2)] + + +def test_build_async_resume_not_already_complete_when_extension_fits_in_slack( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """Non-aligned extension fitting in the last group's slack must not falsely trigger 'already complete'. + + original_target=5, buffer_size=2 β†’ 3 row groups; extending to num_records=6: + ceil(6/2)=3 == len(completed_ids)=3 used to trigger the false 'already complete' branch. + Correct total_row_groups = 3 + ceil(1/2) = 4, so _prepare_async_run must be called. 
+ """ + import asyncio as stdlib_asyncio + + dataset_dir = tmp_path / "dataset" + _write_metadata(dataset_dir, target_num_records=5, buffer_size=2, num_completed_batches=3, actual_num_records=5) + _write_parquet_files(dataset_dir / "parquet-files", [0, 1, 2]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + + def capturing_prepare(*args, **kwargs): + mock_scheduler = Mock() + mock_scheduler.traces = [] + mock_buffer_manager = Mock() + mock_buffer_manager.actual_num_records = 6 + return mock_scheduler, mock_buffer_manager + + mock_future = Mock() + mock_future.result = Mock(return_value=None) + + with patch.object(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE", True): + with patch.object(builder_mod, "asyncio", stdlib_asyncio, create=True): + with patch.object(builder_mod, "ensure_async_engine_loop", Mock(return_value=Mock()), create=True): + with patch.object(stdlib_asyncio, "run_coroutine_threadsafe", return_value=mock_future): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_prepare_async_run", side_effect=capturing_prepare) as mock_prepare: + builder.build(num_records=6, resume=ResumeMode.ALWAYS) + + # _prepare_async_run must be called β€” the dataset is NOT already complete + mock_prepare.assert_called_once() + + +def test_if_possible_incompatible_config_does_not_overwrite_existing_dataset( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """IF_POSSIBLE + incompatible config must NOT resolve to the existing dataset directory. + + Bug: _check_resume_config_compatibility() used base_dataset_path, triggering the + resolved_dataset_name cached_property while artifact_storage.resume was still IF_POSSIBLE. + The property cached the existing directory name; after resume was reset to NEVER locally, + artifact_storage.resume was never updated, so _write_builder_config() still wrote into the + old directory. 
+ + Fix: _check_resume_config_compatibility() uses artifact_path/dataset_name directly and + build() syncs artifact_storage.resume = NEVER before the first real access to base_dataset_path. + """ + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir() + sentinel = dataset_dir / "important_file.txt" + sentinel.write_text("precious data") + + storage = _ArtifactStorage(artifact_path=tmp_path, resume=ResumeMode.IF_POSSIBLE) + stub_resource_provider.artifact_storage = storage + + builder = DatasetBuilder( + data_designer_config=stub_test_config_builder.build(), + resource_provider=stub_resource_provider, + ) + + # Simulate incompatible config and mock out all I/O so build() does not actually generate data + with patch.object(builder, "_check_resume_config_compatibility", return_value=_ConfigCompatibility.INCOMPATIBLE): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_run_mcp_tool_check_if_needed"): + with patch.object(builder, "_write_builder_config"): + with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): + with patch.object(builder.batch_manager, "start"): + with patch.object(builder.batch_manager, "finish"): + with patch.object(builder._processor_runner, "run_after_generation"): + builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) + + # artifact_storage.resume must be downgraded to NEVER so resolved_dataset_name uses NEVER semantics + assert storage.resume == ResumeMode.NEVER + + # resolved_dataset_name has not been cached yet (compat check bypassed base_dataset_path, + # _write_builder_config was mocked). Accessing it now must give a timestamped name. 
+ assert sentinel.exists(), "Existing dataset directory must not be touched" + assert storage.resolved_dataset_name != "dataset", ( + "resolved_dataset_name must be a new timestamped directory, not the existing one" + ) + + +def test_if_possible_incompatible_config_refreshes_media_storage_path( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """After IF_POSSIBLE β†’ NEVER downgrade, _media_storage must point to the new timestamped dir. + + Bug: validate_folder_names initialises MediaStorage with base_dataset_path at Pydantic + construction time (while resume=IF_POSSIBLE), caching the original directory name. + After the cache pop and resume=NEVER, base_dataset_path resolves to a new timestamped + directory, but _media_storage.base_path still holds the old path β€” producing broken + image references for image-column datasets. + + Fix: refresh_media_storage_path() is called after the cache pop. + """ + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir() + (dataset_dir / "existing_file.parquet").write_text("data") # non-empty dir triggers NEVERβ†’timestamp + + storage = _ArtifactStorage(artifact_path=tmp_path, resume=ResumeMode.IF_POSSIBLE) + stub_resource_provider.artifact_storage = storage + + # Trigger validate_folder_names so _media_storage is initialised with IF_POSSIBLE semantics + # (non-empty dir + IF_POSSIBLE β†’ resolved_dataset_name returns "dataset", not timestamped) + original_media_base = storage.media_storage.base_path + + builder = DatasetBuilder( + data_designer_config=stub_test_config_builder.build(), + resource_provider=stub_resource_provider, + ) + + with patch.object(builder, "_check_resume_config_compatibility", return_value=_ConfigCompatibility.INCOMPATIBLE): + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_run_mcp_tool_check_if_needed"): + with patch.object(builder, "_write_builder_config"): + with patch.object(builder, "_initialize_generators_and_graph", 
return_value=([], None)): + with patch.object(builder.batch_manager, "start"): + with patch.object(builder.batch_manager, "finish"): + with patch.object(builder._processor_runner, "run_after_generation"): + builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) + + new_media_base = storage.media_storage.base_path + assert new_media_base != original_media_base, ( + "media_storage.base_path must be updated to the new timestamped directory after IF_POSSIBLE β†’ NEVER downgrade" + ) + assert new_media_base == storage.base_dataset_path, ( + "media_storage.base_path must match base_dataset_path after downgrade" + ) + + +def test_if_possible_starts_fresh_when_no_existing_directory( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """IF_POSSIBLE on a first-ever run (no dataset directory) must start fresh, not raise. + + Bug: _check_resume_config_compatibility returned True when config_path did not exist, + which caused IF_POSSIBLE to upgrade to ALWAYS. resolved_dataset_name then raised + ArtifactStorageError because ALWAYS requires an existing directory. + + Fix: return False when the dataset directory itself is absent. 
+ """ + storage = _ArtifactStorage(artifact_path=tmp_path, resume=ResumeMode.IF_POSSIBLE) + stub_resource_provider.artifact_storage = storage + + builder = DatasetBuilder( + data_designer_config=stub_test_config_builder.build(), + resource_provider=stub_resource_provider, + ) + + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_run_mcp_tool_check_if_needed"): + with patch.object(builder, "_write_builder_config"): + with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): + with patch.object(builder.batch_manager, "start"): + with patch.object(builder.batch_manager, "finish"): + with patch.object(builder._processor_runner, "run_after_generation"): + builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) + + assert storage.resume == ResumeMode.NEVER + + +def test_if_possible_starts_fresh_when_directory_is_empty(stub_resource_provider, stub_test_config_builder, tmp_path): + """IF_POSSIBLE on an empty dataset directory must start fresh, not raise. + + Edge case: a prior run crashed in the window between mkdir and the first file write + inside _write_builder_config, leaving an empty directory. _check_resume_config_compatibility + previously returned True (config file absent β†’ assume compatible), causing IF_POSSIBLE to + upgrade to ALWAYS, which then raised ArtifactStorageError because the directory is empty. + + Fix: treat an empty directory the same as a missing one β€” return False. 
+ """ + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir() # empty β€” no files written yet + + storage = _ArtifactStorage(artifact_path=tmp_path, resume=ResumeMode.IF_POSSIBLE) + stub_resource_provider.artifact_storage = storage + + builder = DatasetBuilder( + data_designer_config=stub_test_config_builder.build(), + resource_provider=stub_resource_provider, + ) + + with patch.object(builder, "_run_model_health_check_if_needed"): + with patch.object(builder, "_run_mcp_tool_check_if_needed"): + with patch.object(builder, "_write_builder_config"): + with patch.object(builder, "_initialize_generators_and_graph", return_value=([], None)): + with patch.object(builder.batch_manager, "start"): + with patch.object(builder.batch_manager, "finish"): + with patch.object(builder._processor_runner, "run_after_generation"): + builder.build(num_records=2, resume=ResumeMode.IF_POSSIBLE) + + assert storage.resume == ResumeMode.NEVER diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py index e2529e1fa..cf3a600bf 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dataset_batch_manager.py @@ -451,3 +451,40 @@ def test_full_workflow(stub_batch_manager): # Verify all files exist assert stub_batch_manager.artifact_storage.metadata_file_path.exists() assert len(list(stub_batch_manager.artifact_storage.final_dataset_path.glob("*.parquet"))) == 3 + + +# --------------------------------------------------------------------------- +# start() with resume parameters +# --------------------------------------------------------------------------- + + +def test_start_with_start_batch(stub_batch_manager): + """start_batch shifts _current_batch_number so the loop skips already-done batches.""" + 
stub_batch_manager.start(num_records=10, buffer_size=3, start_batch=2) + + assert stub_batch_manager._current_batch_number == 2 + assert stub_batch_manager.num_batches == 4 + assert stub_batch_manager.buffer_is_empty is True + + +def test_start_with_initial_actual_num_records(stub_batch_manager): + """initial_actual_num_records pre-populates the running total for resumed runs.""" + stub_batch_manager.start(num_records=10, buffer_size=3, initial_actual_num_records=6) + + assert stub_batch_manager._actual_num_records == 6 + + +def test_start_with_start_batch_and_initial_actual_num_records(stub_batch_manager): + """Both resume params can be set together.""" + stub_batch_manager.start(num_records=10, buffer_size=3, start_batch=2, initial_actual_num_records=6) + + assert stub_batch_manager._current_batch_number == 2 + assert stub_batch_manager._actual_num_records == 6 + + +def test_start_default_values_unchanged(stub_batch_manager): + """Default call (no resume params) still starts at batch 0 with 0 actual records.""" + stub_batch_manager.start(num_records=10, buffer_size=3) + + assert stub_batch_manager._current_batch_number == 0 + assert stub_batch_manager._actual_num_records == 0 diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py index 37b6f71ac..08dad3730 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py @@ -223,3 +223,35 @@ def test_checkpoint_calls_on_complete_when_all_rows_dropped() -> None: callback.assert_called_once_with(None) storage.write_batch_to_parquet_file.assert_not_called() + + +def test_initial_actual_num_records() -> None: + """initial_actual_num_records pre-seeds the actual_num_records counter.""" + storage = _mock_artifact_storage() + 
storage.write_batch_to_parquet_file.return_value = "/fake/path.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake/final.parquet" + + mgr = RowGroupBufferManager(storage, initial_actual_num_records=10) + mgr.init_row_group(0, 3) + mgr.update_batch(0, "col", ["a", "b", "c"]) + mgr.checkpoint_row_group(0) + + assert mgr.actual_num_records == 13 + + +def test_initial_total_num_batches_reflected_in_metadata() -> None: + """initial_total_num_batches pre-seeds the batch counter used by write_metadata.""" + storage = _mock_artifact_storage() + storage.write_batch_to_parquet_file.return_value = "/fake/path.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake/final.parquet" + + mgr = RowGroupBufferManager(storage, initial_actual_num_records=5, initial_total_num_batches=2) + mgr.init_row_group(2, 2) + mgr.update_batch(2, "col", ["x", "y"]) + mgr.checkpoint_row_group(2) + + mgr.write_metadata(target_num_records=9, buffer_size=3) + + written = storage.write_metadata.call_args[0][0] + assert written["num_completed_batches"] == 3 # 2 initial + 1 new + assert written["actual_num_records"] == 7 # 5 initial + 2 new diff --git a/packages/data-designer-engine/tests/engine/storage/test_artifact_storage.py b/packages/data-designer-engine/tests/engine/storage/test_artifact_storage.py index 6206d5bbc..17e30d860 100644 --- a/packages/data-designer-engine/tests/engine/storage/test_artifact_storage.py +++ b/packages/data-designer-engine/tests/engine/storage/test_artifact_storage.py @@ -13,7 +13,7 @@ import data_designer.lazy_heavy_imports as lazy from data_designer.config.utils.io_helpers import load_processor_dataset from data_designer.engine.dataset_builders.errors import ArtifactStorageError -from data_designer.engine.storage.artifact_storage import ArtifactStorage, BatchStage +from data_designer.engine.storage.artifact_storage import ArtifactStorage, BatchStage, ResumeMode @pytest.fixture @@ -412,3 +412,77 @@ def 
test_standalone_load_processor_dataset_raises_file_not_found(tmp_path): """Standalone function raises FileNotFoundError (not ArtifactStorageError).""" with pytest.raises(FileNotFoundError, match="No artifacts found"): load_processor_dataset(tmp_path, "nonexistent") + + +# --------------------------------------------------------------------------- +# Resume flag tests +# --------------------------------------------------------------------------- + + +def test_resolved_dataset_name_creates_timestamped_copy_when_folder_exists(tmp_path): + """Default behaviour: existing non-empty folder gets a timestamped sibling.""" + existing = tmp_path / "dataset" + existing.mkdir() + (existing / "some_file.txt").write_text("x") + + storage = ArtifactStorage(artifact_path=tmp_path, dataset_name="dataset") + name = storage.resolved_dataset_name + assert name != "dataset" + assert name.startswith("dataset_") + + +def test_resolved_dataset_name_resume_uses_existing_folder(tmp_path): + """With resume=ALWAYS, an existing non-empty folder is used as-is.""" + existing = tmp_path / "dataset" + existing.mkdir() + (existing / "some_file.txt").write_text("x") + + storage = ArtifactStorage(artifact_path=tmp_path, dataset_name="dataset", resume=ResumeMode.ALWAYS) + assert storage.resolved_dataset_name == "dataset" + + +def test_resolved_dataset_name_resume_raises_when_no_existing_folder(tmp_path): + """With resume=ALWAYS, missing dataset folder raises ArtifactStorageError.""" + with pytest.raises(ArtifactStorageError, match="Cannot resume"): + ArtifactStorage(artifact_path=tmp_path, dataset_name="dataset", resume=ResumeMode.ALWAYS) + + +def test_resolved_dataset_name_resume_raises_when_folder_is_empty(tmp_path): + """With resume=ALWAYS, an empty existing folder raises ArtifactStorageError.""" + (tmp_path / "dataset").mkdir() + + with pytest.raises(ArtifactStorageError, match="Cannot resume"): + ArtifactStorage(artifact_path=tmp_path, dataset_name="dataset", resume=ResumeMode.ALWAYS) + + +def 
test_resolved_dataset_name_if_possible_uses_existing_folder(tmp_path): + """With resume=IF_POSSIBLE, an existing non-empty folder is used as-is.""" + existing = tmp_path / "dataset" + existing.mkdir() + (existing / "some_file.txt").write_text("x") + + storage = ArtifactStorage(artifact_path=tmp_path, dataset_name="dataset", resume=ResumeMode.IF_POSSIBLE) + assert storage.resolved_dataset_name == "dataset" + + +def test_resolved_dataset_name_if_possible_uses_clean_name_when_no_existing_folder(tmp_path): + """With resume=IF_POSSIBLE, a missing dataset folder results in a fresh run (no error).""" + storage = ArtifactStorage(artifact_path=tmp_path, dataset_name="dataset", resume=ResumeMode.IF_POSSIBLE) + assert storage.resolved_dataset_name == "dataset" + + +def test_clear_partial_results_removes_partial_folder(tmp_path, stub_sample_dataframe): + """clear_partial_results() deletes the partial results directory and its contents.""" + storage = ArtifactStorage(artifact_path=tmp_path) + storage.write_batch_to_parquet_file(0, stub_sample_dataframe, BatchStage.PARTIAL_RESULT) + assert storage.partial_results_path.exists() + + storage.clear_partial_results() + assert not storage.partial_results_path.exists() + + +def test_clear_partial_results_is_noop_when_no_partial_folder(tmp_path): + """clear_partial_results() does not raise when the partial results folder is absent.""" + storage = ArtifactStorage(artifact_path=tmp_path) + assert not storage.partial_results_path.exists() + storage.clear_partial_results() # must not raise diff --git a/packages/data-designer/src/data_designer/cli/commands/create.py b/packages/data-designer/src/data_designer/cli/commands/create.py index ea98222ea..5a739c25a 100644 --- a/packages/data-designer/src/data_designer/cli/commands/create.py +++ b/packages/data-designer/src/data_designer/cli/commands/create.py @@ -8,6 +8,7 @@ from data_designer.cli.controllers.generation_controller import GenerationController from data_designer.config.utils.constants 
import DEFAULT_NUM_RECORDS +from data_designer.engine.storage.artifact_storage import ResumeMode from data_designer.interface.results import SUPPORTED_EXPORT_FORMATS @@ -37,6 +38,18 @@ def create_command( "-o", help="Path where generated artifacts will be stored. Defaults to ./artifacts.", ), + resume: ResumeMode = typer.Option( + ResumeMode.NEVER, + "--resume", + "-r", + help=( + "Resume an interrupted generation run. " + "'never' (default): always start fresh. " + "'always': resume from the last checkpoint; raise if config changed. " + "'if_possible': resume if config matches, otherwise start fresh silently." + ), + case_sensitive=False, + ), output_format: str | None = typer.Option( None, "--output-format", @@ -61,8 +74,11 @@ def create_command( # Create with custom settings data-designer create my_config.yaml --num-records 1000 --dataset-name my_dataset - # Create from a remote config URL - data-designer create https://example.com/my_config.json --dataset-name my_dataset + # Resume an interrupted run + data-designer create my_config.yaml --resume always + + # Resume if config unchanged, otherwise start fresh + data-designer create my_config.yaml --resume if_possible # Create from a Python module with custom output path data-designer create my_config.py --artifact-path /path/to/output @@ -73,5 +89,6 @@ def create_command( num_records=num_records, dataset_name=dataset_name, artifact_path=artifact_path, + resume=resume, output_format=output_format, ) diff --git a/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py b/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py index 39c45f5f5..4a4231c41 100644 --- a/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py +++ b/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py @@ -16,6 +16,7 @@ from data_designer.cli.utils.sample_records_pager import PAGER_FILENAME, create_sample_records_pager from 
data_designer.config.errors import InvalidConfigError from data_designer.config.utils.constants import DEFAULT_DISPLAY_WIDTH +from data_designer.engine.storage.artifact_storage import ResumeMode from data_designer.interface import DataDesigner from data_designer.logging import LOG_INDENT @@ -116,6 +117,7 @@ def run_create( num_records: int, dataset_name: str, artifact_path: str | None, + resume: ResumeMode = ResumeMode.NEVER, output_format: str | None = None, ) -> None: """Load config, create a full dataset, and save results to disk. @@ -125,6 +127,7 @@ def run_create( num_records: Number of records to generate. dataset_name: Name for the generated dataset folder. artifact_path: Path where generated artifacts will be stored, or None for default. + resume: Controls how interrupted runs are handled. output_format: If set, export the dataset to a single file in this format after generation. One of 'jsonl', 'csv', 'parquet'. """ @@ -145,6 +148,7 @@ def run_create( config_builder, num_records=num_records, dataset_name=dataset_name, + resume=resume, ) except Exception as e: print_error(f"Dataset creation failed: {e}") diff --git a/packages/data-designer/src/data_designer/interface/__init__.py b/packages/data-designer/src/data_designer/interface/__init__.py index a8a1bd61b..c5e426e35 100644 --- a/packages/data-designer/src/data_designer/interface/__init__.py +++ b/packages/data-designer/src/data_designer/interface/__init__.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from data_designer.engine.storage.artifact_storage import ResumeMode # noqa: F401 from data_designer.interface.data_designer import DataDesigner # noqa: F401 from data_designer.interface.errors import ( # noqa: F401 DataDesignerEarlyShutdownError, @@ -21,6 +22,7 @@ "DataDesignerGenerationError": ("data_designer.interface.errors", "DataDesignerGenerationError"), "DataDesignerProfilingError": ("data_designer.interface.errors", "DataDesignerProfilingError"), "DatasetCreationResults": 
("data_designer.interface.results", "DatasetCreationResults"), + "ResumeMode": ("data_designer.engine.storage.artifact_storage", "ResumeMode"), } __all__ = list(_LAZY_IMPORTS.keys()) diff --git a/packages/data-designer/src/data_designer/interface/data_designer.py b/packages/data-designer/src/data_designer/interface/data_designer.py index 9f142afa6..bbdeb6ddc 100644 --- a/packages/data-designer/src/data_designer/interface/data_designer.py +++ b/packages/data-designer/src/data_designer/interface/data_designer.py @@ -61,7 +61,7 @@ PlaintextResolver, SecretResolver, ) -from data_designer.engine.storage.artifact_storage import ArtifactStorage +from data_designer.engine.storage.artifact_storage import ArtifactStorage, ResumeMode from data_designer.interface.errors import ( DataDesignerEarlyShutdownError, DataDesignerGenerationError, @@ -217,6 +217,7 @@ def create( *, num_records: int = DEFAULT_NUM_RECORDS, dataset_name: str = "dataset", + resume: ResumeMode = ResumeMode.NEVER, ) -> DatasetCreationResults: """Create dataset and save results to the local artifact storage. @@ -234,6 +235,20 @@ def create( a datetime stamp. For example, if the dataset name is "awesome_dataset" and a directory with the same name already exists, the dataset will be saved to a new directory with the name "awesome_dataset_2025-01-01_12-00-00". + resume: Controls how interrupted runs are handled. + + - ``ResumeMode.NEVER`` (default): always start a fresh generation run. + - ``ResumeMode.ALWAYS``: resume from the last completed batch (sync) or row group + (async). ``buffer_size`` must match the original run. ``num_records`` may be + equal to or greater than what was already generated (you can extend the dataset); + ``num_records`` less than actual records so far raises ``DatasetGenerationError``. + If no checkpoint exists yet (interrupted before the first batch finished), silently + restarts from the beginning. Raises if the stored config is incompatible. 
+ - ``ResumeMode.IF_POSSIBLE``: like ``ALWAYS`` when the current config fingerprint + matches the stored config; otherwise starts a fresh run without raising an error. + + In all resume modes, in-flight partial results from the interrupted run are + discarded before generation continues. Returns: DatasetCreationResults object with methods for loading the generated dataset, @@ -246,11 +261,11 @@ def create( logger.info("🎨 Creating Data Designer dataset") self._log_jinja_rendering_engine_mode() - resource_provider = self._create_resource_provider(dataset_name, config_builder) + resource_provider = self._create_resource_provider(dataset_name, config_builder, resume=resume) try: builder = self._create_dataset_builder(config_builder.build(), resource_provider) - builder.build(num_records=num_records) + builder.build(num_records=num_records, resume=resume) except DeprecationWarning: raise except Exception as e: @@ -561,7 +576,7 @@ def _create_dataset_profiler( ) def _create_resource_provider( - self, dataset_name: str, config_builder: DataDesignerConfigBuilder + self, dataset_name: str, config_builder: DataDesignerConfigBuilder, *, resume: ResumeMode = ResumeMode.NEVER ) -> ResourceProvider: ArtifactStorage.mkdir_if_needed(self._artifact_path) @@ -570,7 +585,9 @@ def _create_resource_provider( seed_dataset_source = seed_config.source return create_resource_provider( - artifact_storage=ArtifactStorage(artifact_path=self._artifact_path, dataset_name=dataset_name), + artifact_storage=ArtifactStorage( + artifact_path=self._artifact_path, dataset_name=dataset_name, resume=resume + ), model_configs=config_builder.model_configs, secret_resolver=self._secret_resolver, model_provider_registry=self._model_provider_registry, diff --git a/packages/data-designer/tests/cli/commands/test_create_command.py b/packages/data-designer/tests/cli/commands/test_create_command.py index fc779df7c..8b3335d4e 100644 --- a/packages/data-designer/tests/cli/commands/test_create_command.py +++ 
b/packages/data-designer/tests/cli/commands/test_create_command.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch from data_designer.cli.commands.create import create_command +from data_designer.engine.storage.artifact_storage import ResumeMode # --------------------------------------------------------------------------- # create_command delegation tests @@ -19,7 +20,12 @@ def test_create_command_delegates_to_controller(mock_ctrl_cls: MagicMock) -> Non mock_ctrl_cls.return_value = mock_ctrl create_command( - config_source="config.yaml", num_records=10, dataset_name="dataset", artifact_path=None, output_format=None + config_source="config.yaml", + num_records=10, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.NEVER, + output_format=None, ) mock_ctrl_cls.assert_called_once() @@ -28,6 +34,7 @@ def test_create_command_delegates_to_controller(mock_ctrl_cls: MagicMock) -> Non num_records=10, dataset_name="dataset", artifact_path=None, + resume=ResumeMode.NEVER, output_format=None, ) @@ -43,6 +50,7 @@ def test_create_command_passes_custom_options(mock_ctrl_cls: MagicMock) -> None: num_records=100, dataset_name="my_data", artifact_path="/custom/output", + resume=ResumeMode.NEVER, output_format=None, ) @@ -51,6 +59,7 @@ def test_create_command_passes_custom_options(mock_ctrl_cls: MagicMock) -> None: num_records=100, dataset_name="my_data", artifact_path="/custom/output", + resume=ResumeMode.NEVER, output_format=None, ) @@ -62,7 +71,12 @@ def test_create_command_default_artifact_path_is_none(mock_ctrl_cls: MagicMock) mock_ctrl_cls.return_value = mock_ctrl create_command( - config_source="config.yaml", num_records=5, dataset_name="ds", artifact_path=None, output_format=None + config_source="config.yaml", + num_records=5, + dataset_name="ds", + artifact_path=None, + resume=ResumeMode.NEVER, + output_format=None, ) mock_ctrl.run_create.assert_called_once_with( @@ -70,6 +84,57 @@ def 
test_create_command_default_artifact_path_is_none(mock_ctrl_cls: MagicMock) num_records=5, dataset_name="ds", artifact_path=None, + resume=ResumeMode.NEVER, + output_format=None, + ) + + +@patch("data_designer.cli.commands.create.GenerationController") +def test_create_command_passes_resume_always(mock_ctrl_cls: MagicMock) -> None: + """Test create_command forwards --resume always to the controller.""" + mock_ctrl = MagicMock() + mock_ctrl_cls.return_value = mock_ctrl + + create_command( + config_source="config.yaml", + num_records=10, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.ALWAYS, + output_format=None, + ) + + mock_ctrl.run_create.assert_called_once_with( + config_source="config.yaml", + num_records=10, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.ALWAYS, + output_format=None, + ) + + +@patch("data_designer.cli.commands.create.GenerationController") +def test_create_command_passes_resume_if_possible(mock_ctrl_cls: MagicMock) -> None: + """Test create_command forwards --resume if_possible to the controller.""" + mock_ctrl = MagicMock() + mock_ctrl_cls.return_value = mock_ctrl + + create_command( + config_source="config.yaml", + num_records=10, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.IF_POSSIBLE, + output_format=None, + ) + + mock_ctrl.run_create.assert_called_once_with( + config_source="config.yaml", + num_records=10, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.IF_POSSIBLE, output_format=None, ) @@ -85,6 +150,7 @@ def test_create_command_passes_output_format(mock_ctrl_cls: MagicMock) -> None: num_records=10, dataset_name="dataset", artifact_path=None, + resume=ResumeMode.NEVER, output_format="jsonl", ) @@ -93,5 +159,6 @@ def test_create_command_passes_output_format(mock_ctrl_cls: MagicMock) -> None: num_records=10, dataset_name="dataset", artifact_path=None, + resume=ResumeMode.NEVER, output_format="jsonl", ) diff --git 
a/packages/data-designer/tests/cli/controllers/test_generation_controller.py b/packages/data-designer/tests/cli/controllers/test_generation_controller.py index 151f2cbb4..b8047641a 100644 --- a/packages/data-designer/tests/cli/controllers/test_generation_controller.py +++ b/packages/data-designer/tests/cli/controllers/test_generation_controller.py @@ -14,6 +14,7 @@ from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.errors import InvalidConfigError from data_designer.config.utils.constants import DEFAULT_DISPLAY_WIDTH +from data_designer.engine.storage.artifact_storage import ResumeMode _CTRL = "data_designer.cli.controllers.generation_controller" _DW = DEFAULT_DISPLAY_WIDTH @@ -27,8 +28,8 @@ def _make_mock_preview_results(num_records: int) -> MagicMock: return mock_results -def _make_mock_create_results(num_records: int, base_path: str = "/output/artifacts/dataset") -> MagicMock: - """Create a mock CreateResults with the given number of records.""" +def _make_mock_create_results(num_records: int = 0, base_path: str = "/output/artifacts/dataset") -> MagicMock: + """Create a mock DatasetCreationResults.""" mock_results = MagicMock() mock_results.count_records.return_value = num_records mock_results.artifact_storage.base_dataset_path = base_path @@ -675,14 +676,16 @@ def test_run_create_success(mock_load_config: MagicMock, mock_dd_cls: MagicMock) mock_dd = MagicMock() mock_dd_cls.return_value = mock_dd - mock_dd.create.return_value = _make_mock_create_results(10) + mock_dd.create.return_value = _make_mock_create_results() controller = GenerationController() controller.run_create(config_source="config.yaml", num_records=10, dataset_name="dataset", artifact_path=None) mock_load_config.assert_called_once_with("config.yaml") mock_dd_cls.assert_called_once_with(artifact_path=Path.cwd() / "artifacts") - mock_dd.create.assert_called_once_with(mock_builder, num_records=10, dataset_name="dataset") + 
mock_dd.create.assert_called_once_with( + mock_builder, num_records=10, dataset_name="dataset", resume=ResumeMode.NEVER + ) @patch(f"{_CTRL}.DataDesigner") @@ -692,7 +695,7 @@ def test_run_create_custom_options(mock_load_config: MagicMock, mock_dd_cls: Mag mock_load_config.return_value = MagicMock(spec=DataDesignerConfigBuilder) mock_dd = MagicMock() mock_dd_cls.return_value = mock_dd - mock_dd.create.return_value = _make_mock_create_results(100, "/custom/output/my_data") + mock_dd.create.return_value = _make_mock_create_results(base_path="/custom/output/my_data") controller = GenerationController() controller.run_create( @@ -703,7 +706,9 @@ def test_run_create_custom_options(mock_load_config: MagicMock, mock_dd_cls: Mag ) mock_dd_cls.assert_called_once_with(artifact_path=Path("/custom/output")) - mock_dd.create.assert_called_once_with(mock_load_config.return_value, num_records=100, dataset_name="my_data") + mock_dd.create.assert_called_once_with( + mock_load_config.return_value, num_records=100, dataset_name="my_data", resume=ResumeMode.NEVER + ) @patch(f"{_CTRL}.load_config_builder") @@ -741,7 +746,7 @@ def test_run_create_calls_to_report_when_analysis_present(mock_load_config: Magi mock_load_config.return_value = MagicMock(spec=DataDesignerConfigBuilder) mock_dd = MagicMock() mock_dd_cls.return_value = mock_dd - mock_results = _make_mock_create_results(10) + mock_results = _make_mock_create_results() mock_analysis = MagicMock() mock_results.load_analysis.return_value = mock_analysis mock_dd.create.return_value = mock_results @@ -760,7 +765,7 @@ def test_run_create_skips_report_when_analysis_is_none(mock_load_config: MagicMo mock_load_config.return_value = MagicMock(spec=DataDesignerConfigBuilder) mock_dd = MagicMock() mock_dd_cls.return_value = mock_dd - mock_results = _make_mock_create_results(10) + mock_results = _make_mock_create_results() mock_results.load_analysis.return_value = None mock_dd.create.return_value = mock_results @@ -772,6 +777,29 @@ def 
test_run_create_skips_report_when_analysis_is_none(mock_load_config: MagicMo mock_results.load_analysis.assert_called_once() +@patch(f"{_CTRL}.DataDesigner") +@patch(f"{_CTRL}.load_config_builder") +def test_run_create_passes_resume_always(mock_load_config: MagicMock, mock_dd_cls: MagicMock) -> None: + """run_create forwards resume=ALWAYS to DataDesigner.create().""" + mock_load_config.return_value = MagicMock(spec=DataDesignerConfigBuilder) + mock_dd = MagicMock() + mock_dd_cls.return_value = mock_dd + mock_dd.create.return_value = _make_mock_create_results() + + controller = GenerationController() + controller.run_create( + config_source="config.yaml", + num_records=10, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.ALWAYS, + ) + + mock_dd.create.assert_called_once_with( + mock_load_config.return_value, num_records=10, dataset_name="dataset", resume=ResumeMode.ALWAYS + ) + + @patch(f"{_CTRL}.DataDesigner") @patch(f"{_CTRL}.load_config_builder") def test_run_create_with_output_format_happy_path(mock_load_config: MagicMock, mock_dd_cls: MagicMock) -> None: @@ -796,6 +824,29 @@ def test_run_create_with_output_format_happy_path(mock_load_config: MagicMock, m ) +@patch(f"{_CTRL}.DataDesigner") +@patch(f"{_CTRL}.load_config_builder") +def test_run_create_passes_resume_if_possible(mock_load_config: MagicMock, mock_dd_cls: MagicMock) -> None: + """run_create forwards resume=IF_POSSIBLE to DataDesigner.create().""" + mock_load_config.return_value = MagicMock(spec=DataDesignerConfigBuilder) + mock_dd = MagicMock() + mock_dd_cls.return_value = mock_dd + mock_dd.create.return_value = _make_mock_create_results() + + controller = GenerationController() + controller.run_create( + config_source="config.yaml", + num_records=10, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.IF_POSSIBLE, + ) + + mock_dd.create.assert_called_once_with( + mock_load_config.return_value, num_records=10, dataset_name="dataset", resume=ResumeMode.IF_POSSIBLE + ) 
+ + @patch(f"{_CTRL}.DataDesigner") @patch(f"{_CTRL}.load_config_builder") def test_run_create_export_failure_exits(mock_load_config: MagicMock, mock_dd_cls: MagicMock, tmp_path: Path) -> None: diff --git a/packages/data-designer/tests/cli/test_main.py b/packages/data-designer/tests/cli/test_main.py index 56dc7d2ad..33b620a35 100644 --- a/packages/data-designer/tests/cli/test_main.py +++ b/packages/data-designer/tests/cli/test_main.py @@ -11,6 +11,7 @@ from data_designer.cli.main import app, main from data_designer.cli.version_notice import UpdateNotice from data_designer.config.utils.constants import DEFAULT_NUM_RECORDS +from data_designer.engine.storage.artifact_storage import ResumeMode runner = CliRunner() @@ -163,5 +164,6 @@ def test_app_dispatches_lazy_create_command(mock_controller_cls: Mock) -> None: num_records=DEFAULT_NUM_RECORDS, dataset_name="dataset", artifact_path=None, + resume=ResumeMode.NEVER, output_format=None, )