Skip to content

Commit 0b56d4d

Browse files
airhorns and claude committed
Feature: Run audits concurrently using concurrent_tasks setting
Adds two levels of audit concurrency: 1. Per-model (SnapshotEvaluator): audits within a single snapshot now run concurrently via concurrent_apply_to_values, controlled by concurrent_tasks. This benefits both plan/apply and audit-only runs. 2. Cross-model (Scheduler): when audit_only=True, all audit tasks across all snapshots are flattened into a single thread pool instead of following DAG ordering. Since audits are read-only SELECT queries with no side effects, DAG dependencies are irrelevant and all concurrent_tasks slots stay filled. The SnapshotEvaluator parameter ddl_concurrent_tasks is renamed to concurrent_tasks to reflect its broader scope. Closes #5468 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d8d653f commit 0b56d4d

File tree

5 files changed

+623
-134
lines changed

5 files changed

+623
-134
lines changed

sqlmesh/core/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def snapshot_evaluator(self) -> SnapshotEvaluator:
492492
gateway: adapter.with_settings(execute_log_level=logging.INFO)
493493
for gateway, adapter in self.engine_adapters.items()
494494
},
495-
ddl_concurrent_tasks=self.concurrent_tasks,
495+
concurrent_tasks=self.concurrent_tasks,
496496
selected_gateway=self.selected_gateway,
497497
)
498498
return self._snapshot_evaluator

sqlmesh/core/scheduler.py

Lines changed: 210 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass
33
import abc
44
import logging
5+
import threading
56
import typing as t
67
import time
78
from datetime import datetime
@@ -37,7 +38,7 @@
3738
)
3839
from sqlmesh.core.state_sync import StateSync
3940
from sqlmesh.utils import CompletionStatus
40-
from sqlmesh.utils.concurrency import concurrent_apply_to_dag, NodeExecutionFailedError
41+
from sqlmesh.utils.concurrency import concurrent_apply_to_dag, concurrent_apply_to_values, NodeExecutionFailedError
4142
from sqlmesh.utils.dag import DAG
4243
from sqlmesh.utils.date import (
4344
TimeLike,
@@ -499,110 +500,92 @@ def run_merged_intervals(
499500
selected_models=selected_models,
500501
)
501502

502-
# We only need to create physical tables if the snapshot is not representative or if it
503-
# needs backfill
504-
snapshots_to_create_candidates = [
505-
s
506-
for s in selected_snapshots
507-
if not deployability_index.is_representative(s) or s in batched_intervals
508-
]
509-
snapshots_to_create = {
510-
s.snapshot_id
511-
for s in self.snapshot_evaluator.get_snapshots_to_create(
512-
snapshots_to_create_candidates, deployability_index
513-
)
514-
}
515-
516-
dag = self._dag(
517-
batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create
518-
)
519-
520-
def run_node(node: SchedulingUnit) -> None:
521-
if circuit_breaker and circuit_breaker():
522-
raise CircuitBreakerError()
523-
if isinstance(node, DummyNode):
524-
return
525-
526-
snapshot = self.snapshots_by_name[node.snapshot_name]
527-
528-
if isinstance(node, EvaluateNode):
529-
self.console.start_snapshot_evaluation_progress(snapshot)
530-
execution_start_ts = now_timestamp()
531-
evaluation_duration_ms: t.Optional[int] = None
532-
start, end = node.interval
533-
534-
audit_results: t.List[AuditResult] = []
535-
try:
536-
assert execution_time # mypy
537-
assert deployability_index # mypy
538-
539-
if audit_only:
540-
audit_results = self._audit_snapshot(
541-
snapshot=snapshot,
542-
environment_naming_info=environment_naming_info,
543-
deployability_index=deployability_index,
544-
snapshots=self.snapshots_by_name,
545-
start=start,
546-
end=end,
547-
execution_time=execution_time,
548-
)
549-
else:
550-
# If batch_index > 0, then the target table must exist since the first batch would have created it
551-
target_table_exists = (
552-
snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0
553-
)
554-
audit_results = self.evaluate(
555-
snapshot=snapshot,
556-
environment_naming_info=environment_naming_info,
557-
start=start,
558-
end=end,
559-
execution_time=execution_time,
560-
deployability_index=deployability_index,
561-
batch_index=node.batch_index,
562-
allow_destructive_snapshots=allow_destructive_snapshots,
563-
allow_additive_snapshots=allow_additive_snapshots,
564-
target_table_exists=target_table_exists,
565-
selected_models=selected_models,
503+
try:
504+
with self.snapshot_evaluator.concurrent_context():
505+
if audit_only:
506+
errors, skipped_intervals = self._run_audits_concurrently(
507+
batched_intervals=batched_intervals,
508+
deployability_index=deployability_index,
509+
environment_naming_info=environment_naming_info,
510+
execution_time=execution_time,
511+
circuit_breaker=circuit_breaker,
512+
auto_restatement_triggers=auto_restatement_triggers,
513+
)
514+
else:
515+
# We only need to create physical tables if the snapshot is not representative
516+
# or if it needs backfill
517+
snapshots_to_create_candidates = [
518+
s
519+
for s in selected_snapshots
520+
if not deployability_index.is_representative(s) or s in batched_intervals
521+
]
522+
snapshots_to_create = {
523+
s.snapshot_id
524+
for s in self.snapshot_evaluator.get_snapshots_to_create(
525+
snapshots_to_create_candidates, deployability_index
566526
)
527+
}
567528

568-
evaluation_duration_ms = now_timestamp() - execution_start_ts
569-
finally:
570-
num_audits = len(audit_results)
571-
num_audits_failed = sum(1 for result in audit_results if result.count)
572-
573-
execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats(
574-
SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index)
529+
dag = self._dag(
530+
batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create
575531
)
576532

577-
self.console.update_snapshot_evaluation_progress(
578-
snapshot,
579-
batched_intervals[snapshot][node.batch_index],
580-
node.batch_index,
581-
evaluation_duration_ms,
582-
num_audits - num_audits_failed,
583-
num_audits_failed,
584-
execution_stats=execution_stats,
585-
auto_restatement_triggers=auto_restatement_triggers.get(
586-
snapshot.snapshot_id
587-
),
533+
def run_node(node: SchedulingUnit) -> None:
534+
if circuit_breaker and circuit_breaker():
535+
raise CircuitBreakerError()
536+
if isinstance(node, DummyNode):
537+
return
538+
539+
snapshot = self.snapshots_by_name[node.snapshot_name]
540+
541+
if isinstance(node, EvaluateNode):
542+
assert execution_time # mypy
543+
assert deployability_index # mypy
544+
node_start, node_end = node.interval
545+
546+
# If batch_index > 0, then the target table must exist since the first batch would have created it
547+
target_table_exists = (
548+
snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0
549+
)
550+
551+
def _do_evaluate() -> t.List[AuditResult]:
552+
return self.evaluate(
553+
snapshot=snapshot,
554+
environment_naming_info=environment_naming_info,
555+
start=node_start,
556+
end=node_end,
557+
execution_time=execution_time,
558+
deployability_index=deployability_index,
559+
batch_index=node.batch_index,
560+
allow_destructive_snapshots=allow_destructive_snapshots,
561+
allow_additive_snapshots=allow_additive_snapshots,
562+
target_table_exists=target_table_exists,
563+
selected_models=selected_models,
564+
)
565+
566+
self._run_node_with_progress(
567+
snapshot=snapshot,
568+
node=node,
569+
batched_intervals=batched_intervals,
570+
auto_restatement_triggers=auto_restatement_triggers,
571+
work_fn=_do_evaluate,
572+
)
573+
elif isinstance(node, CreateNode):
574+
self.snapshot_evaluator.create_snapshot(
575+
snapshot=snapshot,
576+
snapshots=self.snapshots_by_name,
577+
deployability_index=deployability_index,
578+
allow_destructive_snapshots=allow_destructive_snapshots or set(),
579+
allow_additive_snapshots=allow_additive_snapshots or set(),
580+
)
581+
582+
errors, skipped_intervals = concurrent_apply_to_dag(
583+
dag,
584+
run_node,
585+
self.max_workers,
586+
raise_on_error=False,
588587
)
589-
elif isinstance(node, CreateNode):
590-
self.snapshot_evaluator.create_snapshot(
591-
snapshot=snapshot,
592-
snapshots=self.snapshots_by_name,
593-
deployability_index=deployability_index,
594-
allow_destructive_snapshots=allow_destructive_snapshots or set(),
595-
allow_additive_snapshots=allow_additive_snapshots or set(),
596-
)
597588

598-
try:
599-
with self.snapshot_evaluator.concurrent_context():
600-
errors, skipped_intervals = concurrent_apply_to_dag(
601-
dag,
602-
run_node,
603-
self.max_workers,
604-
raise_on_error=False,
605-
)
606589
self.console.stop_evaluation_progress(success=not errors)
607590

608591
skipped_snapshots = {
@@ -947,6 +930,134 @@ def _audit_snapshot(
947930

948931
return audit_results
949932

933+
def _run_node_with_progress(
    self,
    *,
    snapshot: Snapshot,
    node: EvaluateNode,
    batched_intervals: t.Dict[Snapshot, Intervals],
    auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]],
    work_fn: t.Callable[[], t.List[AuditResult]],
) -> None:
    """Executes ``work_fn`` for a single node while reporting progress and audit outcomes.

    The console is notified when evaluation starts and — regardless of whether
    ``work_fn`` succeeds or raises — is updated afterwards with the elapsed
    time, the audit pass/fail counts, and the execution statistics recorded
    for the node's batch.

    Args:
        snapshot: The snapshot being processed.
        node: The evaluate node.
        batched_intervals: The batched intervals per snapshot.
        auto_restatement_triggers: Auto restatement trigger info per snapshot.
        work_fn: A callable that performs the actual work and returns audit results.
    """
    self.console.start_snapshot_evaluation_progress(snapshot)
    started_at = now_timestamp()
    duration_ms: t.Optional[int] = None

    results: t.List[AuditResult] = []
    try:
        results = work_fn()
        duration_ms = now_timestamp() - started_at
    finally:
        # Progress is reported even when work_fn raises, so the console's
        # view of the run stays consistent; duration_ms remains None then.
        failed_count = sum(1 for result in results if result.count)
        passed_count = len(results) - failed_count

        stats = self.snapshot_evaluator.execution_tracker.get_execution_stats(
            SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index)
        )

        self.console.update_snapshot_evaluation_progress(
            snapshot,
            batched_intervals[snapshot][node.batch_index],
            node.batch_index,
            duration_ms,
            passed_count,
            failed_count,
            execution_stats=stats,
            auto_restatement_triggers=auto_restatement_triggers.get(snapshot.snapshot_id),
        )
979+
980+
def _run_audits_concurrently(
    self,
    *,
    batched_intervals: t.Dict[Snapshot, Intervals],
    deployability_index: DeployabilityIndex,
    environment_naming_info: EnvironmentNamingInfo,
    execution_time: TimeLike,
    circuit_breaker: t.Optional[t.Callable[[], bool]],
    auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]],
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
    """Audits every snapshot interval in one flat thread pool.

    Audits are read-only SELECT queries with no side effects, so DAG ordering
    between snapshots is unnecessary; flattening all tasks keeps every
    concurrent_tasks slot busy instead of draining the DAG level by level.

    Args:
        batched_intervals: The batched intervals to audit per snapshot.
        deployability_index: Determines snapshots that are deployable.
        environment_naming_info: The environment naming info.
        execution_time: The date/time reference to use for execution time.
        circuit_breaker: An optional handler which checks if the run should be aborted.
        auto_restatement_triggers: Auto restatement trigger info per snapshot.

    Returns:
        A tuple of errors and skipped intervals (always empty for audit-only runs).
    """
    # One task per (snapshot, interval, batch_index), flattened across all snapshots.
    tasks: t.List[EvaluateNode] = [
        EvaluateNode(snapshot_name=snap.name, interval=ivl, batch_index=idx)
        for snap, snap_intervals in batched_intervals.items()
        for idx, ivl in enumerate(snap_intervals)
    ]

    failures: t.List[NodeExecutionFailedError[SchedulingUnit]] = []
    failures_lock = threading.Lock()

    def _audit_one(node: EvaluateNode) -> None:
        # The circuit breaker is only consulted as each task starts; tasks
        # already submitted to the pool run to completion. Unlike the DAG
        # executor's level-by-level cancellation, that is acceptable here
        # because audits are read-only and leave nothing behind.
        if circuit_breaker and circuit_breaker():
            raise CircuitBreakerError()

        snapshot = self.snapshots_by_name[node.snapshot_name]
        start, end = node.interval

        def _work() -> t.List[AuditResult]:
            return self._audit_snapshot(
                snapshot=snapshot,
                environment_naming_info=environment_naming_info,
                deployability_index=deployability_index,
                snapshots=self.snapshots_by_name,
                start=start,
                end=end,
                execution_time=execution_time,
            )

        self._run_node_with_progress(
            snapshot=snapshot,
            node=node,
            batched_intervals=batched_intervals,
            auto_restatement_triggers=auto_restatement_triggers,
            work_fn=_work,
        )

    def _audit_one_capturing(node: EvaluateNode) -> None:
        try:
            _audit_one(node)
        except Exception as ex:
            failure: NodeExecutionFailedError[SchedulingUnit] = NodeExecutionFailedError(node)
            failure.__cause__ = ex
            with failures_lock:
                failures.append(failure)

    concurrent_apply_to_values(tasks, _audit_one_capturing, self.max_workers)

    return failures, []
1060+
9501061
def _check_ready_intervals(
9511062
self,
9521063
snapshot: Snapshot,

0 commit comments

Comments
 (0)