diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index ace57190f..8f8637ca4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -41,6 +41,7 @@ from data_designer.engine.dataset_builders.scheduling.resolver import TaskSchedulingResolver from data_designer.engine.dataset_builders.scheduling.resources import ( SchedulableTask, + request_scheduler_resource_key, stable_task_id, ) from data_designer.engine.dataset_builders.scheduling.task_admission import ( @@ -145,6 +146,15 @@ class _DispatchOutcome: admission_blocked: bool = False +@dataclass(frozen=True) +class _DeferredAdmissionAnalysis: + """Deferred retry pressure as seen by adaptive row-group admission.""" + + blocks: bool + candidate_columns: tuple[str, ...] + independent_candidate_columns: tuple[str, ...] + + class AsyncTaskScheduler: """Dependency-aware async task scheduler for the dataset builder. @@ -329,6 +339,12 @@ def __init__( self._row_group_admission_pressure_ticks = 0 self._row_group_admission_blocked_reasons: Counter[str] = Counter() self._adaptive_max_admitted_rows = self._max_admitted_rows_guardrail() + self._row_group_admission_pending: tuple[int, int] | None = None + self._deferred_admission_cache: tuple[tuple[Task, ...], tuple[int, int], _DeferredAdmissionAnalysis] | None = ( + None + ) + self._transitive_upstream_cache: dict[str, frozenset[str]] = {} + self._transitive_downstream_cache: dict[str, frozenset[str]] = {} self._request_pressure_provider = request_pressure_provider self._request_pressure_advisory = request_pressure_advisory and request_pressure_provider is not None self._request_pressure_advisory_skips = 0 @@ -944,13 +960,14 @@ def _maybe_update_adaptive_row_group_target(self) -> None: self._row_group_admission_event.set() def _adaptive_row_group_block_reason(self) -> str | None: - if self._deferred: - return "deferred_tasks" - next_size = self._next_unadmitted_row_group_size() - if next_size is None: + next_row_group = self._next_unadmitted_row_group() + if next_row_group is None: return "no_pending_row_groups" + next_rg_id, next_size = next_row_group if not self._row_group_row_guard_allows(next_size): return "max_admitted_rows" + if self._deferred and self._deferred_admission_analysis(next_rg_id, next_size).blocks: + return "deferred_tasks" queue_view = self._fair_queue.view() queue_guard = self._max_in_flight_tasks * 4 if queue_view.queued_total >= queue_guard: @@ -968,13 +985,169 @@ def _adaptive_row_group_block_reason(self) -> str | None: return "queued_llm_demand" return None - def _next_unadmitted_row_group_size(self) -> int | None: - for rg_id, rg_size in self._row_groups: - if rg_id not in self._rg_states and not self._tracker.is_row_group_complete( - rg_id, rg_size, self._graph.columns - ): - return rg_size - return None + def _next_unadmitted_row_group(self) -> tuple[int, int] | None: + pending = self._row_group_admission_pending + if pending is None: + return None + rg_id, rg_size = pending + if rg_id in self._rg_states or self._tracker.is_row_group_complete(rg_id, rg_size, self._graph.columns): + return None + return pending + + def _deferred_admission_analysis(self, row_group: int, row_group_size: int) -> _DeferredAdmissionAnalysis: + cache_key = (tuple(self._deferred), (row_group, row_group_size)) + if self._deferred_admission_cache is not None and self._deferred_admission_cache[:2] == cache_key: + return self._deferred_admission_cache[2] + deferred_items = tuple(self._schedulable_task(task) for task in self._deferred) + deferred_keys = {key for item in deferred_items for key in self._localized_deferred_admission_keys(item)} + candidates = tuple( + (item, self._localized_deferred_admission_keys(item)) + for item in self._row_group_admission_candidate_tasks(row_group, row_group_size) + ) + blocked_columns: set[str] = set() + for item in deferred_items: + blocked_columns.update(self._task_output_columns(item.payload)) + for item, keys in candidates: + if keys & deferred_keys: + blocked_columns.update(self._task_output_columns(item.payload)) + independent_candidates = tuple( + item.payload.column + for item, keys in candidates + if not (keys & deferred_keys) + and not self._task_depends_on_any(item.payload, blocked_columns) + and ( + self._is_resource_scoped_admission_candidate(item) + or not self._task_reaches_any(item.payload, blocked_columns) + ) + ) + blocks = bool(deferred_items) and not independent_candidates + analysis = _DeferredAdmissionAnalysis( + blocks=blocks, + candidate_columns=tuple(item.payload.column for item, _keys in candidates), + independent_candidate_columns=independent_candidates, + ) + self._deferred_admission_cache = (*cache_key, analysis) + return analysis + + def _row_group_admission_candidate_tasks( + self, + row_group: int, + row_group_size: int, + ) -> tuple[SchedulableTask, ...]: + tasks: list[SchedulableTask] = [] + seen_generators: set[int] = set() + for column in self._graph.get_topological_order(): + generator_id = id(self._generators[column]) + if generator_id in seen_generators: + continue + seen_generators.add(generator_id) + strategy = self._graph.get_strategy(column) + if strategy == GenerationStrategy.CELL_BY_CELL: + if row_group_size <= 0: + continue + task = Task(column=column, row_group=row_group, row_index=0, task_type="cell") + elif column in self._seed_cols: + task = Task(column=column, row_group=row_group, row_index=None, task_type="from_scratch") + else: + task = Task(column=column, row_group=row_group, row_index=None, task_type="batch") + tasks.append(self._schedulable_task(task)) + return tuple(tasks) + + def _localized_deferred_admission_keys(self, item: SchedulableTask) -> set[str]: + if item.request_resource_key is not None: + resource = item.request_resource_key + return { + f"request_resource:{_request_resource_label(resource)}", + f"scheduler_resource:{request_scheduler_resource_key(resource)}", + } + identity = "/".join(item.group.key.identity) + return {f"group:{item.group.key.kind}:{identity}"} + + @staticmethod + def _is_localized_admission_resource(resource: str) -> bool: + return resource.startswith("request:") + + def _is_resource_scoped_admission_candidate(self, item: SchedulableTask) -> bool: + return item.request_resource_key is not None or item.group.key.kind != "local" + + def _task_output_columns(self, task: Task) -> tuple[str, ...]: + return self._task_flow_identity(task) or (task.column,) + + def _task_depends_on_any(self, task: Task, blocked_columns: set[str]) -> bool: + return any(self._column_depends_on_any(column, blocked_columns) for column in self._task_output_columns(task)) + + def _task_reaches_any(self, task: Task, blocked_columns: set[str]) -> bool: + return any(self._column_reaches_any(column, blocked_columns) for column in self._task_output_columns(task)) + + def _column_depends_on_any(self, column: str, blocked_columns: set[str]) -> bool: + return bool(self._transitive_upstream_columns(column) & blocked_columns) + + def _column_reaches_any(self, column: str, blocked_columns: set[str]) -> bool: + return bool(self._transitive_downstream_columns(column) & blocked_columns) + + def _transitive_upstream_columns(self, column: str) -> frozenset[str]: + cached = self._transitive_upstream_cache.get(column) + if cached is not None: + return cached + result = self._walk_graph(column, upstream=True) + self._transitive_upstream_cache[column] = result + return result + + def _transitive_downstream_columns(self, column: str) -> frozenset[str]: + cached = self._transitive_downstream_cache.get(column) + if cached is not None: + return cached + result = self._walk_graph(column, upstream=False) + self._transitive_downstream_cache[column] = result + return result + + def _walk_graph(self, column: str, *, upstream: bool) -> frozenset[str]: + next_columns = self._graph.get_upstream_columns if upstream else self._graph.get_downstream_columns + to_visit = list(next_columns(column)) + seen: set[str] = set() + while to_visit: + next_column = to_visit.pop() + if next_column in seen: + continue + seen.add(next_column) + to_visit.extend(next_columns(next_column)) + return frozenset(seen) + + def _deferred_admission_diagnostics(self) -> dict[str, object]: + deferred_items = tuple(self._schedulable_task(task) for task in self._deferred) + diagnostics: dict[str, object] = { + "count": len(self._deferred), + "scope": "localized" if self._deferred else "none", + "blocks_next_row_group": False, + "columns": dict(Counter(task.column for task in self._deferred)), + "request_resources": {}, + "scheduler_resources": {}, + "candidate_columns": (), + "independent_candidate_columns": (), + } + if not self._deferred: + return diagnostics + request_resource_counts = Counter( + label + for item in deferred_items + if (label := _request_resource_label(item.request_resource_key)) is not None + ) + scheduler_resource_counts = Counter( + resource + for item in deferred_items + for resource in item.resource_request.amounts + if self._is_localized_admission_resource(resource) + ) + diagnostics["request_resources"] = dict(request_resource_counts) + diagnostics["scheduler_resources"] = dict(scheduler_resource_counts) + next_row_group = self._next_unadmitted_row_group() + if next_row_group is None: + return diagnostics + analysis = self._deferred_admission_analysis(*next_row_group) + diagnostics["blocks_next_row_group"] = analysis.blocks + diagnostics["candidate_columns"] = analysis.candidate_columns + diagnostics["independent_candidate_columns"] = analysis.independent_candidate_columns + return diagnostics def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]: queue_view = self._fair_queue.view() @@ -999,42 +1172,48 @@ def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]: "llm_wait_leased": task_view.leased_resources.get("llm_wait", 0), "llm_wait_available": task_view.resources_available.get("llm_wait", 0), "blocked_reasons": dict(self._row_group_admission_blocked_reasons), + "deferred_admission": self._deferred_admission_diagnostics(), } async def _admit_row_groups(self) -> None: """Admit row groups as semaphore slots become available.""" all_admitted = True - for rg_id, rg_size in self._row_groups: - await self._wait_for_row_group_admission_capacity(rg_size) - if self._early_shutdown or self._fatal_worker_error is not None: - all_admitted = False - break - await self._rg_semaphore.acquire() - if self._early_shutdown or self._fatal_worker_error is not None: - self._rg_semaphore.release() - all_admitted = False - break - if not self._row_group_row_guard_allows(rg_size): - self._rg_semaphore.release() + try: + for rg_id, rg_size in self._row_groups: + self._row_group_admission_pending = (rg_id, rg_size) await self._wait_for_row_group_admission_capacity(rg_size) + if self._early_shutdown or self._fatal_worker_error is not None: + all_admitted = False + break await self._rg_semaphore.acquire() if self._early_shutdown or self._fatal_worker_error is not None: self._rg_semaphore.release() all_admitted = False break - self._rg_states[rg_id] = _RowGroupState(size=rg_size) - - if self._buffer_manager is not None: - self._buffer_manager.init_row_group(rg_id, rg_size) - - await self._dispatch_seeds(rg_id, rg_size) - self._emit_scheduler_event( - "row_group_admitted", - diagnostics=self._row_group_admission_diagnostics(reason="admitted") - | {"row_group": rg_id, "row_group_size": rg_size}, - ) - self._emit_scheduler_health_snapshot("row_group_admitted") - self._wake_event.set() + if not self._row_group_row_guard_allows(rg_size): + self._rg_semaphore.release() + await self._wait_for_row_group_admission_capacity(rg_size) + await self._rg_semaphore.acquire() + if self._early_shutdown or self._fatal_worker_error is not None: + self._rg_semaphore.release() + all_admitted = False + break + self._row_group_admission_pending = None + self._rg_states[rg_id] = _RowGroupState(size=rg_size) + + if self._buffer_manager is not None: + self._buffer_manager.init_row_group(rg_id, rg_size) + + await self._dispatch_seeds(rg_id, rg_size) + self._emit_scheduler_event( + "row_group_admitted", + diagnostics=self._row_group_admission_diagnostics(reason="admitted") + | {"row_group": rg_id, "row_group_size": rg_size}, + ) + self._emit_scheduler_health_snapshot("row_group_admitted") + self._wake_event.set() + finally: + self._row_group_admission_pending = None self._all_rgs_admitted = all_admitted self._wake_event.set() diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index bb5cf5685..a3934eb03 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -3156,19 +3156,21 @@ def __init__( *args: Any, provider_name: str = "provider", model_id: str = "model", + generation_kind: str = "chat", request_weight: int = 1, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self._provider_name = provider_name self._model_id = model_id + self._generation_kind = generation_kind self._request_weight = request_weight def get_scheduling_metadata(self) -> SchedulingMetadata: return SchedulingMetadata.model( self._provider_name, self._model_id, - "chat", + self._generation_kind, weight=self._request_weight, ) @@ -3753,6 +3755,480 @@ async def test_scheduler_adaptive_row_group_admission_expands_target_for_horizon assert any(event.event_kind == "row_group_admission_target_changed" for event in sink.scheduler_events) +def _build_adaptive_model_resource_scheduler( + *, + columns: tuple[str, ...] = ("cooling", "healthy"), + row_groups: list[tuple[int, int]] | None = None, +) -> tuple[AsyncTaskScheduler, CompletionTracker]: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + *[LLMTextColumnConfig(name=column, prompt="{{ seed }}", model_alias=MODEL_ALIAS) for column in columns], + ] + strategies: dict[str, GenerationStrategy] = {"seed": GenerationStrategy.FULL_COLUMN} + strategies.update({column: GenerationStrategy.CELL_BY_CELL for column in columns}) + generators: dict[str, ColumnGenerator] = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + } + for column in columns: + generators[column] = SlowModelBoundCellGenerator( + config=_expr_config(column), + resource_provider=provider, + provider_name="provider", + model_id=column, + delay=0.0, + ) + graph = ExecutionGraph.create(configs, strategies) + row_groups = row_groups or [(0, 1), (1, 1)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=max(2, len(row_groups)), + max_in_flight_tasks=16, + max_model_task_admission=16, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=1, + num_records=sum(size for _row_group, size in row_groups), + buffer_size=1, + ) + return scheduler, tracker + + +def _set_next_row_group_pending(scheduler: AsyncTaskScheduler, row_group: int = 1, row_group_size: int = 1) -> None: + scheduler._row_group_admission_pending = (row_group, row_group_size) + + +def test_scheduler_adaptive_row_group_ignores_unrelated_deferred_retry_resource() -> None: + scheduler, _tracker = _build_adaptive_model_resource_scheduler() + scheduler._rg_states[0] = SimpleNamespace(size=1, seeds_dispatched=True, pre_batch_done=True, in_flight_count=0) + _set_next_row_group_pending(scheduler) + deferred = Task(column="cooling", row_group=0, row_index=0, task_type="cell") + scheduler._deferred = [deferred] + scheduler._deferred_errors[deferred] = ModelRateLimitError("429 Too Many Requests") + + diagnostics = scheduler._row_group_admission_diagnostics(reason="probe")["deferred_admission"] + + assert scheduler._adaptive_row_group_block_reason() is None + assert diagnostics["blocks_next_row_group"] is False + assert diagnostics["scope"] == "localized" + assert diagnostics["columns"] == {"cooling": 1} + assert diagnostics["request_resources"] == {"provider/cooling/chat": 1} + assert "healthy" in diagnostics["independent_candidate_columns"] + + +def test_scheduler_adaptive_row_group_blocks_same_deferred_retry_resource() -> None: + scheduler, _tracker = _build_adaptive_model_resource_scheduler(columns=("cooling",)) + scheduler._rg_states[0] = SimpleNamespace(size=1, seeds_dispatched=True, pre_batch_done=True, in_flight_count=0) + _set_next_row_group_pending(scheduler) + deferred = Task(column="cooling", row_group=0, row_index=0, task_type="cell") + scheduler._deferred = [deferred] + scheduler._deferred_errors[deferred] = ModelRateLimitError("429 Too Many Requests") + + diagnostics = scheduler._row_group_admission_diagnostics(reason="probe")["deferred_admission"] + + assert scheduler._adaptive_row_group_block_reason() == "deferred_tasks" + assert diagnostics["blocks_next_row_group"] is True + assert set(diagnostics["candidate_columns"]) == {"seed", "cooling"} + assert diagnostics["independent_candidate_columns"] == () + + +def test_scheduler_adaptive_row_group_blocks_downstream_candidate_behind_deferred_resource() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cooling_a", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="cooling_b", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="healthy", prompt="{{ cooling_b }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cooling_a": GenerationStrategy.CELL_BY_CELL, + "cooling_b": GenerationStrategy.CELL_BY_CELL, + "healthy": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1), (1, 1)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cooling_a": SlowModelBoundCellGenerator( + config=_expr_config("cooling_a"), + resource_provider=provider, + provider_name="provider", + model_id="cooling", + delay=0.0, + ), + "cooling_b": SlowModelBoundCellGenerator( + config=_expr_config("cooling_b"), + resource_provider=provider, + provider_name="provider", + model_id="cooling", + delay=0.0, + ), + "healthy": SlowModelBoundCellGenerator( + config=_expr_config("healthy"), + resource_provider=provider, + provider_name="provider", + model_id="healthy", + delay=0.0, + ), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=2, + max_in_flight_tasks=16, + max_model_task_admission=16, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=1, + num_records=2, + buffer_size=1, + ) + scheduler._rg_states[0] = SimpleNamespace(size=1, seeds_dispatched=True, pre_batch_done=True, in_flight_count=0) + _set_next_row_group_pending(scheduler) + deferred = Task(column="cooling_a", row_group=0, row_index=0, task_type="cell") + scheduler._deferred = [deferred] + scheduler._deferred_errors[deferred] = ModelRateLimitError("429 Too Many Requests") + + diagnostics = scheduler._row_group_admission_diagnostics(reason="probe")["deferred_admission"] + + assert scheduler._adaptive_row_group_block_reason() == "deferred_tasks" + assert diagnostics["blocks_next_row_group"] is True + assert set(diagnostics["candidate_columns"]) == {"seed", "cooling_a", "cooling_b", "healthy"} + assert diagnostics["independent_candidate_columns"] == () + + +def test_scheduler_adaptive_row_group_blocks_multi_output_sibling_dependency() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cooling_a", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="cooling_b", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="healthy", prompt="{{ cooling_b }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cooling_a": GenerationStrategy.CELL_BY_CELL, + "cooling_b": GenerationStrategy.CELL_BY_CELL, + "healthy": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1), (1, 1)] + graph = ExecutionGraph.create(configs, strategies) + cooling = SlowModelBoundCellGenerator( + config=_expr_config("cooling_a"), + resource_provider=provider, + provider_name="provider", + model_id="cooling", + delay=0.0, + ) + scheduler = AsyncTaskScheduler( + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cooling_a": cooling, + "cooling_b": cooling, + "healthy": SlowModelBoundCellGenerator( + config=_expr_config("healthy"), + resource_provider=provider, + provider_name="provider", + model_id="healthy", + delay=0.0, + ), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=2, + max_in_flight_tasks=16, + max_model_task_admission=16, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=1, + num_records=2, + buffer_size=1, + ) + scheduler._rg_states[0] = SimpleNamespace(size=1, seeds_dispatched=True, pre_batch_done=True, in_flight_count=0) + _set_next_row_group_pending(scheduler) + deferred = Task(column="cooling_a", row_group=0, row_index=0, task_type="cell") + scheduler._deferred = [deferred] + scheduler._deferred_errors[deferred] = ModelRateLimitError("429 Too Many Requests") + + diagnostics = scheduler._row_group_admission_diagnostics(reason="probe")["deferred_admission"] + + assert scheduler._adaptive_row_group_block_reason() == "deferred_tasks" + assert diagnostics["blocks_next_row_group"] is True + assert "healthy" not in diagnostics["independent_candidate_columns"] + + +def test_scheduler_adaptive_row_group_blocks_shared_scheduler_resource_across_domains() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="chat_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="embedding_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "chat_col": GenerationStrategy.CELL_BY_CELL, + "embedding_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1), (1, 1)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "chat_col": SlowModelBoundCellGenerator( + config=_expr_config("chat_col"), + resource_provider=provider, + provider_name="provider", + model_id="model", + generation_kind=RequestDomain.CHAT.value, + delay=0.0, + ), + "embedding_col": SlowModelBoundCellGenerator( + config=_expr_config("embedding_col"), + resource_provider=provider, + provider_name="provider", + model_id="model", + generation_kind=RequestDomain.EMBEDDING.value, + delay=0.0, + ), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=2, + max_in_flight_tasks=16, + max_model_task_admission=16, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=1, + num_records=2, + buffer_size=1, + ) + scheduler._rg_states[0] = SimpleNamespace(size=1, seeds_dispatched=True, pre_batch_done=True, in_flight_count=0) + _set_next_row_group_pending(scheduler) + deferred = Task(column="chat_col", row_group=0, row_index=0, task_type="cell") + scheduler._deferred = [deferred] + scheduler._deferred_errors[deferred] = ModelRateLimitError("429 Too Many Requests") + + diagnostics = scheduler._row_group_admission_diagnostics(reason="probe")["deferred_admission"] + + assert scheduler._adaptive_row_group_block_reason() == "deferred_tasks" + assert diagnostics["blocks_next_row_group"] is True + assert diagnostics["request_resources"] == {"provider/model/chat": 1} + assert diagnostics["scheduler_resources"] == {"request:provider/model": 1} + assert diagnostics["independent_candidate_columns"] == () + + +def test_scheduler_adaptive_row_group_counts_independent_local_branch() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cooling", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="local_branch", expr="'local'", dtype="str"), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cooling": GenerationStrategy.CELL_BY_CELL, + "local_branch": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1), (1, 1)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cooling": SlowModelBoundCellGenerator( + config=_expr_config("cooling"), + resource_provider=provider, + provider_name="provider", + model_id="cooling", + delay=0.0, + ), + "local_branch": MockCellGenerator(config=_expr_config("local_branch"), resource_provider=provider), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=2, + max_in_flight_tasks=16, + max_model_task_admission=16, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=1, + num_records=2, + buffer_size=1, + ) + scheduler._rg_states[0] = SimpleNamespace(size=1, seeds_dispatched=True, pre_batch_done=True, in_flight_count=0) + _set_next_row_group_pending(scheduler) + deferred = Task(column="cooling", row_group=0, row_index=0, task_type="cell") + scheduler._deferred = [deferred] + scheduler._deferred_errors[deferred] = ModelRateLimitError("429 Too Many Requests") + + diagnostics = scheduler._row_group_admission_diagnostics(reason="probe")["deferred_admission"] + + assert scheduler._adaptive_row_group_block_reason() is None + assert diagnostics["blocks_next_row_group"] is False + assert "local_branch" in diagnostics["independent_candidate_columns"] + assert "seed" not in diagnostics["independent_candidate_columns"] + + +def test_scheduler_adaptive_row_group_localizes_custom_model_deferred_retry() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cooling", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="healthy", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cooling": GenerationStrategy.CELL_BY_CELL, + "healthy": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1), (1, 1)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cooling": SlowLLMBoundCellGenerator( + config=_expr_config("cooling"), + resource_provider=provider, + delay=0.0, + ), + "healthy": SlowLLMBoundCellGenerator( + config=_expr_config("healthy"), + resource_provider=provider, + delay=0.0, + ), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=2, + max_in_flight_tasks=16, + max_model_task_admission=16, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=1, + num_records=2, + buffer_size=1, + ) + scheduler._rg_states[0] = SimpleNamespace(size=1, seeds_dispatched=True, pre_batch_done=True, in_flight_count=0) + _set_next_row_group_pending(scheduler) + deferred = Task(column="cooling", row_group=0, row_index=0, task_type="cell") + scheduler._deferred = [deferred] + scheduler._deferred_errors[deferred] = ModelRateLimitError("429 Too Many Requests") + + diagnostics = scheduler._row_group_admission_diagnostics(reason="probe")["deferred_admission"] + + assert scheduler._adaptive_row_group_block_reason() is None + assert diagnostics["blocks_next_row_group"] is False + assert diagnostics["request_resources"] == {} + assert "healthy" in diagnostics["independent_candidate_columns"] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_adaptive_row_group_admission_keeps_healthy_resource_exposed_during_deferred_cooldown( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = _mock_provider() + monkeypatch.setattr(async_scheduler_module, "RETRYABLE_RESALVAGE_BACKOFF_S", 0.001) + healthy_threshold = asyncio.Event() + + class AlwaysCoolingModelGenerator(SlowModelBoundCellGenerator): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.calls = 0 + + async def agenerate(self, data: dict) -> dict: + self.calls += 1 + raise ModelRateLimitError("429 Too Many Requests") + + class CountingHealthyModelGenerator(SlowModelBoundCellGenerator): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.calls = 0 + + async def agenerate(self, data: dict) -> dict: + self.calls += 1 + if self.calls >= 16: + healthy_threshold.set() + return await super().agenerate(data) + + cooling = AlwaysCoolingModelGenerator( + config=_expr_config("cooling"), + resource_provider=provider, + provider_name="provider", + model_id="cooling", + request_weight=1, + delay=0.0, + ) + healthy = CountingHealthyModelGenerator( + config=_expr_config("healthy"), + resource_provider=provider, + provider_name="provider", + model_id="healthy", + request_weight=8, + delay=0.0, + ) + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cooling", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="healthy", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cooling": GenerationStrategy.CELL_BY_CELL, + "healthy": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(row_group, 1) for row_group in range(64)] + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cooling": cooling, + "healthy": healthy, + }, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=32, + max_in_flight_tasks=32, + max_model_task_admission=32, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=1, + salvage_max_rounds=1, + num_records=64, + buffer_size=1, + ) + + run_task = asyncio.create_task(scheduler.run()) + threshold_task = asyncio.create_task(healthy_threshold.wait()) + pending: set[asyncio.Task] = set() + try: + done, pending = await asyncio.wait( + {run_task, threshold_task}, + timeout=5.0, + return_when=asyncio.FIRST_COMPLETED, + ) + if run_task in done: + await run_task + + assert threshold_task in done + assert healthy.calls >= 16 + assert cooling.calls >= 1 + assert scheduler._observed_max_row_group_admission_target > 1 + assert scheduler._observed_max_row_groups_in_flight > 1 + finally: + for pending_task in pending: + pending_task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + run_task.cancel() + try: + await run_task + except asyncio.CancelledError: + pass + + def test_scheduler_adaptive_row_group_row_guard_blocks_extra_large_groups() -> None: provider = _mock_provider() configs = [ @@ -3818,6 +4294,19 @@ def _stub_row_group_admission_resource_views( ) +def test_scheduler_adaptive_row_group_row_guard_precedes_unrelated_deferred_retry() -> None: + row_groups = [(0, 5_000), (1, 5_000)] + scheduler, _tracker = _build_adaptive_model_resource_scheduler(row_groups=row_groups) + scheduler._rg_states[0] = SimpleNamespace(size=5_000, seeds_dispatched=True, pre_batch_done=True, in_flight_count=0) + _set_next_row_group_pending(scheduler, row_group=1, row_group_size=5_000) + deferred = Task(column="cooling", row_group=0, row_index=0, task_type="cell") + scheduler._deferred = [deferred] + scheduler._deferred_errors[deferred] = ModelRateLimitError("429 Too Many Requests") + + assert scheduler._adaptive_max_admitted_rows == 8_192 + assert scheduler._adaptive_row_group_block_reason() == "max_admitted_rows" + + def test_scheduler_adaptive_row_group_block_reason_prefers_llm_saturation() -> None: provider = _mock_provider() configs = [ @@ -3855,6 +4344,7 @@ def test_scheduler_adaptive_row_group_block_reason_prefers_llm_saturation() -> N llm_available=0, llm_leased=1, ) + _set_next_row_group_pending(scheduler) assert scheduler._adaptive_row_group_block_reason() == "llm_wait_saturated" @@ -3869,6 +4359,7 @@ def test_scheduler_adaptive_row_group_block_reason_allows_zero_llm_lease_bootstr llm_available=3, llm_leased=0, ) + _set_next_row_group_pending(scheduler) assert scheduler._adaptive_row_group_block_reason() is None @@ -3883,6 +4374,7 @@ def test_scheduler_adaptive_row_group_block_reason_blocks_queued_llm_demand_afte llm_available=3, llm_leased=1, ) + _set_next_row_group_pending(scheduler) assert scheduler._adaptive_row_group_block_reason() == "queued_llm_demand" @@ -3938,6 +4430,7 @@ def test_scheduler_adaptive_row_group_target_grows_for_zero_llm_lease_bootstrap( llm_available=3, llm_leased=0, ) + _set_next_row_group_pending(scheduler) scheduler._maybe_update_adaptive_row_group_target() assert scheduler._row_group_admission_target == 1 @@ -3966,6 +4459,7 @@ def test_scheduler_adaptive_row_group_target_stays_blocked_after_llm_lease_boots llm_available=3, llm_leased=1, ) + _set_next_row_group_pending(scheduler) scheduler._maybe_update_adaptive_row_group_target() scheduler._maybe_update_adaptive_row_group_target() @@ -3982,6 +4476,7 @@ def test_scheduler_adaptive_row_group_queue_guard_uses_in_flight_task_cap() -> N scheduler._fair_queue = SimpleNamespace( view=lambda: SimpleNamespace(queued_total=8, queued_peer_demand_by_resource={}) ) + _set_next_row_group_pending(scheduler) assert scheduler._adaptive_row_group_block_reason() == "queued_task_guardrail"