|
2 | 2 | from dataclasses import dataclass |
3 | 3 | import abc |
4 | 4 | import logging |
| 5 | +import threading |
5 | 6 | import typing as t |
6 | 7 | import time |
7 | 8 | from datetime import datetime |
|
37 | 38 | ) |
38 | 39 | from sqlmesh.core.state_sync import StateSync |
39 | 40 | from sqlmesh.utils import CompletionStatus |
40 | | -from sqlmesh.utils.concurrency import concurrent_apply_to_dag, NodeExecutionFailedError |
| 41 | +from sqlmesh.utils.concurrency import concurrent_apply_to_dag, concurrent_apply_to_values, NodeExecutionFailedError |
41 | 42 | from sqlmesh.utils.dag import DAG |
42 | 43 | from sqlmesh.utils.date import ( |
43 | 44 | TimeLike, |
@@ -499,110 +500,92 @@ def run_merged_intervals( |
499 | 500 | selected_models=selected_models, |
500 | 501 | ) |
501 | 502 |
|
502 | | - # We only need to create physical tables if the snapshot is not representative or if it |
503 | | - # needs backfill |
504 | | - snapshots_to_create_candidates = [ |
505 | | - s |
506 | | - for s in selected_snapshots |
507 | | - if not deployability_index.is_representative(s) or s in batched_intervals |
508 | | - ] |
509 | | - snapshots_to_create = { |
510 | | - s.snapshot_id |
511 | | - for s in self.snapshot_evaluator.get_snapshots_to_create( |
512 | | - snapshots_to_create_candidates, deployability_index |
513 | | - ) |
514 | | - } |
515 | | - |
516 | | - dag = self._dag( |
517 | | - batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create |
518 | | - ) |
519 | | - |
520 | | - def run_node(node: SchedulingUnit) -> None: |
521 | | - if circuit_breaker and circuit_breaker(): |
522 | | - raise CircuitBreakerError() |
523 | | - if isinstance(node, DummyNode): |
524 | | - return |
525 | | - |
526 | | - snapshot = self.snapshots_by_name[node.snapshot_name] |
527 | | - |
528 | | - if isinstance(node, EvaluateNode): |
529 | | - self.console.start_snapshot_evaluation_progress(snapshot) |
530 | | - execution_start_ts = now_timestamp() |
531 | | - evaluation_duration_ms: t.Optional[int] = None |
532 | | - start, end = node.interval |
533 | | - |
534 | | - audit_results: t.List[AuditResult] = [] |
535 | | - try: |
536 | | - assert execution_time # mypy |
537 | | - assert deployability_index # mypy |
538 | | - |
539 | | - if audit_only: |
540 | | - audit_results = self._audit_snapshot( |
541 | | - snapshot=snapshot, |
542 | | - environment_naming_info=environment_naming_info, |
543 | | - deployability_index=deployability_index, |
544 | | - snapshots=self.snapshots_by_name, |
545 | | - start=start, |
546 | | - end=end, |
547 | | - execution_time=execution_time, |
548 | | - ) |
549 | | - else: |
550 | | - # If batch_index > 0, then the target table must exist since the first batch would have created it |
551 | | - target_table_exists = ( |
552 | | - snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0 |
553 | | - ) |
554 | | - audit_results = self.evaluate( |
555 | | - snapshot=snapshot, |
556 | | - environment_naming_info=environment_naming_info, |
557 | | - start=start, |
558 | | - end=end, |
559 | | - execution_time=execution_time, |
560 | | - deployability_index=deployability_index, |
561 | | - batch_index=node.batch_index, |
562 | | - allow_destructive_snapshots=allow_destructive_snapshots, |
563 | | - allow_additive_snapshots=allow_additive_snapshots, |
564 | | - target_table_exists=target_table_exists, |
565 | | - selected_models=selected_models, |
| 503 | + try: |
| 504 | + with self.snapshot_evaluator.concurrent_context(): |
| 505 | + if audit_only: |
| 506 | + errors, skipped_intervals = self._run_audits_concurrently( |
| 507 | + batched_intervals=batched_intervals, |
| 508 | + deployability_index=deployability_index, |
| 509 | + environment_naming_info=environment_naming_info, |
| 510 | + execution_time=execution_time, |
| 511 | + circuit_breaker=circuit_breaker, |
| 512 | + auto_restatement_triggers=auto_restatement_triggers, |
| 513 | + ) |
| 514 | + else: |
| 515 | + # We only need to create physical tables if the snapshot is not representative |
| 516 | + # or if it needs backfill |
| 517 | + snapshots_to_create_candidates = [ |
| 518 | + s |
| 519 | + for s in selected_snapshots |
| 520 | + if not deployability_index.is_representative(s) or s in batched_intervals |
| 521 | + ] |
| 522 | + snapshots_to_create = { |
| 523 | + s.snapshot_id |
| 524 | + for s in self.snapshot_evaluator.get_snapshots_to_create( |
| 525 | + snapshots_to_create_candidates, deployability_index |
566 | 526 | ) |
| 527 | + } |
567 | 528 |
|
568 | | - evaluation_duration_ms = now_timestamp() - execution_start_ts |
569 | | - finally: |
570 | | - num_audits = len(audit_results) |
571 | | - num_audits_failed = sum(1 for result in audit_results if result.count) |
572 | | - |
573 | | - execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats( |
574 | | - SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index) |
| 529 | + dag = self._dag( |
| 530 | + batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create |
575 | 531 | ) |
576 | 532 |
|
577 | | - self.console.update_snapshot_evaluation_progress( |
578 | | - snapshot, |
579 | | - batched_intervals[snapshot][node.batch_index], |
580 | | - node.batch_index, |
581 | | - evaluation_duration_ms, |
582 | | - num_audits - num_audits_failed, |
583 | | - num_audits_failed, |
584 | | - execution_stats=execution_stats, |
585 | | - auto_restatement_triggers=auto_restatement_triggers.get( |
586 | | - snapshot.snapshot_id |
587 | | - ), |
| 533 | + def run_node(node: SchedulingUnit) -> None: |
| 534 | + if circuit_breaker and circuit_breaker(): |
| 535 | + raise CircuitBreakerError() |
| 536 | + if isinstance(node, DummyNode): |
| 537 | + return |
| 538 | + |
| 539 | + snapshot = self.snapshots_by_name[node.snapshot_name] |
| 540 | + |
| 541 | + if isinstance(node, EvaluateNode): |
| 542 | + assert execution_time # mypy |
| 543 | + assert deployability_index # mypy |
| 544 | + node_start, node_end = node.interval |
| 545 | + |
| 546 | + # If batch_index > 0, then the target table must exist since the first batch would have created it |
| 547 | + target_table_exists = ( |
| 548 | + snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0 |
| 549 | + ) |
| 550 | + |
| 551 | + def _do_evaluate() -> t.List[AuditResult]: |
| 552 | + return self.evaluate( |
| 553 | + snapshot=snapshot, |
| 554 | + environment_naming_info=environment_naming_info, |
| 555 | + start=node_start, |
| 556 | + end=node_end, |
| 557 | + execution_time=execution_time, |
| 558 | + deployability_index=deployability_index, |
| 559 | + batch_index=node.batch_index, |
| 560 | + allow_destructive_snapshots=allow_destructive_snapshots, |
| 561 | + allow_additive_snapshots=allow_additive_snapshots, |
| 562 | + target_table_exists=target_table_exists, |
| 563 | + selected_models=selected_models, |
| 564 | + ) |
| 565 | + |
| 566 | + self._run_node_with_progress( |
| 567 | + snapshot=snapshot, |
| 568 | + node=node, |
| 569 | + batched_intervals=batched_intervals, |
| 570 | + auto_restatement_triggers=auto_restatement_triggers, |
| 571 | + work_fn=_do_evaluate, |
| 572 | + ) |
| 573 | + elif isinstance(node, CreateNode): |
| 574 | + self.snapshot_evaluator.create_snapshot( |
| 575 | + snapshot=snapshot, |
| 576 | + snapshots=self.snapshots_by_name, |
| 577 | + deployability_index=deployability_index, |
| 578 | + allow_destructive_snapshots=allow_destructive_snapshots or set(), |
| 579 | + allow_additive_snapshots=allow_additive_snapshots or set(), |
| 580 | + ) |
| 581 | + |
| 582 | + errors, skipped_intervals = concurrent_apply_to_dag( |
| 583 | + dag, |
| 584 | + run_node, |
| 585 | + self.max_workers, |
| 586 | + raise_on_error=False, |
588 | 587 | ) |
589 | | - elif isinstance(node, CreateNode): |
590 | | - self.snapshot_evaluator.create_snapshot( |
591 | | - snapshot=snapshot, |
592 | | - snapshots=self.snapshots_by_name, |
593 | | - deployability_index=deployability_index, |
594 | | - allow_destructive_snapshots=allow_destructive_snapshots or set(), |
595 | | - allow_additive_snapshots=allow_additive_snapshots or set(), |
596 | | - ) |
597 | 588 |
|
598 | | - try: |
599 | | - with self.snapshot_evaluator.concurrent_context(): |
600 | | - errors, skipped_intervals = concurrent_apply_to_dag( |
601 | | - dag, |
602 | | - run_node, |
603 | | - self.max_workers, |
604 | | - raise_on_error=False, |
605 | | - ) |
606 | 589 | self.console.stop_evaluation_progress(success=not errors) |
607 | 590 |
|
608 | 591 | skipped_snapshots = { |
@@ -947,6 +930,134 @@ def _audit_snapshot( |
947 | 930 |
|
948 | 931 | return audit_results |
949 | 932 |
|
| 933 | + def _run_node_with_progress( |
| 934 | + self, |
| 935 | + *, |
| 936 | + snapshot: Snapshot, |
| 937 | + node: EvaluateNode, |
| 938 | + batched_intervals: t.Dict[Snapshot, Intervals], |
| 939 | + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]], |
| 940 | + work_fn: t.Callable[[], t.List[AuditResult]], |
| 941 | + ) -> None: |
| 942 | + """Runs a work function for a node while tracking progress and audit results. |
| 943 | +
|
| 944 | + Args: |
| 945 | + snapshot: The snapshot being processed. |
| 946 | + node: The evaluate node. |
| 947 | + batched_intervals: The batched intervals per snapshot. |
| 948 | + auto_restatement_triggers: Auto restatement trigger info per snapshot. |
| 949 | + work_fn: A callable that performs the actual work and returns audit results. |
| 950 | + """ |
| 951 | + self.console.start_snapshot_evaluation_progress(snapshot) |
| 952 | + execution_start_ts = now_timestamp() |
| 953 | + evaluation_duration_ms: t.Optional[int] = None |
| 954 | + |
| 955 | + audit_results: t.List[AuditResult] = [] |
| 956 | + try: |
| 957 | + audit_results = work_fn() |
| 958 | + evaluation_duration_ms = now_timestamp() - execution_start_ts |
| 959 | + finally: |
| 960 | + num_audits = len(audit_results) |
| 961 | + num_audits_failed = sum(1 for result in audit_results if result.count) |
| 962 | + |
| 963 | + execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats( |
| 964 | + SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index) |
| 965 | + ) |
| 966 | + |
| 967 | + self.console.update_snapshot_evaluation_progress( |
| 968 | + snapshot, |
| 969 | + batched_intervals[snapshot][node.batch_index], |
| 970 | + node.batch_index, |
| 971 | + evaluation_duration_ms, |
| 972 | + num_audits - num_audits_failed, |
| 973 | + num_audits_failed, |
| 974 | + execution_stats=execution_stats, |
| 975 | + auto_restatement_triggers=auto_restatement_triggers.get( |
| 976 | + snapshot.snapshot_id |
| 977 | + ), |
| 978 | + ) |
| 979 | + |
| 980 | + def _run_audits_concurrently( |
| 981 | + self, |
| 982 | + *, |
| 983 | + batched_intervals: t.Dict[Snapshot, Intervals], |
| 984 | + deployability_index: DeployabilityIndex, |
| 985 | + environment_naming_info: EnvironmentNamingInfo, |
| 986 | + execution_time: TimeLike, |
| 987 | + circuit_breaker: t.Optional[t.Callable[[], bool]], |
| 988 | + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]], |
| 989 | + ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: |
| 990 | + """Runs all audits across all snapshots in a single flat thread pool. |
| 991 | +
|
| 992 | + Audits are read-only SELECT queries with no side effects, so they can safely |
| 993 | + run concurrently even across snapshots that have DAG dependencies. This fills |
| 994 | + all concurrent_tasks slots at once instead of processing level-by-level as the |
| 995 | + DAG executor would. |
| 996 | +
|
| 997 | + Args: |
| 998 | + batched_intervals: The batched intervals to audit per snapshot. |
| 999 | + deployability_index: Determines snapshots that are deployable. |
| 1000 | + environment_naming_info: The environment naming info. |
| 1001 | + execution_time: The date/time reference to use for execution time. |
| 1002 | + circuit_breaker: An optional handler which checks if the run should be aborted. |
| 1003 | + auto_restatement_triggers: Auto restatement trigger info per snapshot. |
| 1004 | +
|
| 1005 | + Returns: |
| 1006 | + A tuple of errors and skipped intervals (always empty for audit-only runs). |
| 1007 | + """ |
| 1008 | + # Flatten all (snapshot, interval, batch_index) tasks across all snapshots |
| 1009 | + audit_tasks: t.List[EvaluateNode] = [ |
| 1010 | + EvaluateNode(snapshot_name=snapshot.name, interval=interval, batch_index=batch_index) |
| 1011 | + for snapshot, intervals in batched_intervals.items() |
| 1012 | + for batch_index, interval in enumerate(intervals) |
| 1013 | + ] |
| 1014 | + |
| 1015 | + errors: t.List[NodeExecutionFailedError[SchedulingUnit]] = [] |
| 1016 | + errors_lock = threading.Lock() |
| 1017 | + |
| 1018 | + def run_audit_task(node: EvaluateNode) -> None: |
| 1019 | + # The circuit breaker is checked at task start. Tasks already submitted to the |
| 1020 | + # thread pool will run to completion — unlike the DAG executor's level-by-level |
| 1021 | + # cancellation, this is acceptable for audit-only runs because audits are |
| 1022 | + # read-only and have no side effects. |
| 1023 | + if circuit_breaker and circuit_breaker(): |
| 1024 | + raise CircuitBreakerError() |
| 1025 | + |
| 1026 | + snapshot = self.snapshots_by_name[node.snapshot_name] |
| 1027 | + node_start, node_end = node.interval |
| 1028 | + |
| 1029 | + def _do_audit() -> t.List[AuditResult]: |
| 1030 | + return self._audit_snapshot( |
| 1031 | + snapshot=snapshot, |
| 1032 | + environment_naming_info=environment_naming_info, |
| 1033 | + deployability_index=deployability_index, |
| 1034 | + snapshots=self.snapshots_by_name, |
| 1035 | + start=node_start, |
| 1036 | + end=node_end, |
| 1037 | + execution_time=execution_time, |
| 1038 | + ) |
| 1039 | + |
| 1040 | + self._run_node_with_progress( |
| 1041 | + snapshot=snapshot, |
| 1042 | + node=node, |
| 1043 | + batched_intervals=batched_intervals, |
| 1044 | + auto_restatement_triggers=auto_restatement_triggers, |
| 1045 | + work_fn=_do_audit, |
| 1046 | + ) |
| 1047 | + |
| 1048 | + def run_audit_task_collecting_errors(node: EvaluateNode) -> None: |
| 1049 | + try: |
| 1050 | + run_audit_task(node) |
| 1051 | + except Exception as ex: |
| 1052 | + error: NodeExecutionFailedError[SchedulingUnit] = NodeExecutionFailedError(node) |
| 1053 | + error.__cause__ = ex |
| 1054 | + with errors_lock: |
| 1055 | + errors.append(error) |
| 1056 | + |
| 1057 | + concurrent_apply_to_values(audit_tasks, run_audit_task_collecting_errors, self.max_workers) |
| 1058 | + |
| 1059 | + return errors, [] |
| 1060 | + |
950 | 1061 | def _check_ready_intervals( |
951 | 1062 | self, |
952 | 1063 | snapshot: Snapshot, |
|
0 commit comments