Commit 69bc602

Feat!: Add the ability to control concurrency between model batches during the evaluation (#2450)

1 parent deb8e9a commit 69bc602

File tree

13 files changed: +205 −42 lines changed

docs/concepts/models/overview.md

Lines changed: 3 additions & 0 deletions

@@ -250,6 +250,9 @@ For models that are incremental, the following parameters can be specified in th
 ### batch_size
 - Batch size is used to optimize backfilling incremental data. It determines the maximum number of intervals to run in a single job. For example, if a model specifies a cron of `@hourly` and a batch_size of `12`, when backfilling 3 days of data, the scheduler will spawn 6 jobs. (3 days * 24 hours/day = 72 hour intervals to fill. 72 intervals / 12 intervals per job = 6 jobs.)
 
+### batch_concurrency
+- The maximum number of [batches](#batch_size) that can run concurrently for this model. If not specified, the concurrency is only constrained by the number of concurrent tasks set in the connection settings.
+
 ### forward_only
 - Set this to true to indicate that all changes to this model should be [forward-only](../plans.md#forward-only-plans).
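The batch arithmetic above, combined with the new `batch_concurrency` cap, can be sketched as a toy planner. This is a hypothetical helper for illustration only, not part of SQLMesh:

```python
import math
from typing import Optional, Tuple


def plan_backfill(
    total_intervals: int, batch_size: int, batch_concurrency: Optional[int]
) -> Tuple[int, int]:
    """Return (number of jobs, max jobs in flight) for a backfill.

    batch_concurrency=None mirrors the default: concurrency is bounded only
    by the connection's concurrent-task limit, so every job may run at once.
    """
    jobs = math.ceil(total_intervals / batch_size)
    in_flight = jobs if batch_concurrency is None else min(batch_concurrency, jobs)
    return jobs, in_flight


# The docs' example: @hourly cron, 3 days of data, batch_size 12.
print(plan_backfill(3 * 24, 12, None))  # (6, 6): 6 jobs, all may run at once
print(plan_backfill(3 * 24, 12, 2))     # (6, 2): 6 jobs, at most 2 concurrent
```

With `batch_concurrency` left unset, behavior matches the pre-change scheduler; setting it only adds an upper bound per model.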

docs/reference/model_configuration.md

Lines changed: 6 additions & 7 deletions

@@ -21,7 +21,6 @@ Configuration options for SQLMesh model properties. Supported by all model kinds
 | `interval_unit` | The temporal granularity of the model's data intervals. Supported values: `year`, `month`, `day`, `hour`, `half_hour`, `quarter_hour`, `five_minute`. (Default: inferred from `cron`) | str | N |
 | `start` | The date/time that determines the earliest date interval that should be processed by a model. Can be a datetime string, epoch time in milliseconds, or a relative datetime such as `1 year ago`. | str \| int | N |
 | `end` | The date/time that determines the latest date interval that should be processed by a model. Can be a datetime string, epoch time in milliseconds, or a relative datetime such as `1 year ago`. | str \| int | N |
-| `batch_size` | The maximum number of intervals that can be evaluated in a single backfill task. If this is `None`, all intervals will be processed as part of a single task. If this is set, a model's backfill will be chunked such that each individual task only contains jobs with the maximum of `batch_size` intervals. (Default: `None`) | int | N |
 | `grains` | The column(s) whose combination uniquely identifies each row in the model | str \| array[str] | N |
 | `references` | The model column(s) used to join to other models' grains | str \| array[str] | N |
 | `depends_on` | Models on which this model depends. (Default: dependencies inferred from model code) | array[str] | N |
@@ -45,7 +44,6 @@ The SQLMesh project-level `model_defaults` key supports the following options, d
 - owner
 - start
 - end
-- batch_size
 - storage_format
 
 ## Model kind properties
@@ -74,10 +72,11 @@ Python model configuration object: [FullKind()](https://sqlmesh.readthedocs.io/e
 
 Configuration options for all incremental models (in addition to [general model properties](#general-model-properties)).
 
-| Option | Description | Type | Required |
-| ------------ | ----------- | :--: | :------: |
-| `batch_size` | The maximum number of intervals that can be evaluated in a single backfill task. If this is `None`, all intervals will be processed as part of a single task. If this is set, a model's backfill will be chunked such that each individual task only contains jobs with the maximum of `batch_size` intervals. (Default: `None`) | int | N |
-| `lookback` | The number of time unit intervals prior to the current interval that should be processed. (Default: `0`) | int | N |
+| Option | Description | Type | Required |
+|---------------------|-------------|:----:|:--------:|
+| `batch_size` | The maximum number of intervals that can be evaluated in a single backfill task. If this is `None`, all intervals will be processed as part of a single task. If this is set, a model's backfill will be chunked such that each individual task only contains jobs with the maximum of `batch_size` intervals. (Default: `None`) | int | N |
+| `batch_concurrency` | The maximum number of batches that can run concurrently for this model (Default: the number of concurrent tasks set in the connection settings). | int | N |
+| `lookback` | The number of time unit intervals prior to the current interval that should be processed. (Default: `0`) | int | N |
 
 #### Incremental by time range
 
@@ -172,4 +171,4 @@ Options specified within the `kind` property's `csv_settings` property (override
 | `lineterminator` | Character used to denote a line break. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | str | N |
 | `encoding` | Encoding to use for UTF when reading/writing (ex. 'utf-8'). More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | str | N |
 
-Python model configuration object: [SeedKind()](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/model/kind.html#SeedKind)
+Python model configuration object: [SeedKind()](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/model/kind.html#SeedKind)

sqlmesh/core/config/model.py

Lines changed: 0 additions & 4 deletions

@@ -19,9 +19,6 @@ class ModelDefaultsConfig(BaseConfig):
         start: The earliest date that the model will be backfilled for. If this is None,
             then the date is inferred by taking the most recent start date of its ancestors.
             The start date can be a static datetime or a relative datetime like "1 year ago"
-        batch_size: The maximum number of intervals that can be run per backfill job. If this is None,
-            then backfilling this model will do all of history in one job. If this is set, a model's backfill
-            will be chunked such that each individual job will only contain jobs with max `batch_size` intervals.
         storage_format: The storage format used to store the physical table, only applicable in certain engines.
             (eg. 'parquet')
     """
@@ -31,7 +28,6 @@ class ModelDefaultsConfig(BaseConfig):
     cron: t.Optional[str] = None
     owner: t.Optional[str] = None
     start: t.Optional[TimeLike] = None
-    batch_size: t.Optional[int] = None
     storage_format: t.Optional[str] = None
 
     _model_kind_validator = model_kind_validator

sqlmesh/core/model/definition.py

Lines changed: 1 addition & 5 deletions

@@ -73,7 +73,6 @@ class _Model(ModelMeta, frozen=True):
            name sushi.order_items,
            owner jen,
            cron '@daily',
-           batch_size 30,
            start '2020-01-01',
            partitioned_by ds
        );
@@ -101,9 +100,6 @@ class _Model(ModelMeta, frozen=True):
            The start date can be a static datetime or a relative datetime like "1 year ago"
        end: The date that the model will be backfilled up until. Follows the same syntax as 'start',
            should be omitted if there is no end date.
-       batch_size: The maximum number of incremental intervals that can be run per backfill job. If this is None,
-           then backfilling this model will do all of history in one job. If this is set, a model's backfill
-           will be chunked such that each individual job will only contain jobs with max `batch_size` intervals.
        lookback: The number of previous incremental intervals in the lookback window.
        storage_format: The storage format used to store the physical table, only applicable in certain engines.
            (eg. 'parquet')
@@ -750,6 +746,7 @@ def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str:
             str(self.end) if self.end else None,
             str(self.retention) if self.retention else None,
             str(self.batch_size) if self.batch_size is not None else None,
+            str(self.batch_concurrency) if self.batch_concurrency is not None else None,
             json.dumps(self.mapping_schema, sort_keys=True),
             *sorted(self.tags),
             *sorted(ref.json(sort_keys=True) for ref in self.all_references),
@@ -2024,7 +2021,6 @@ def _refs_to_sql(values: t.Any) -> exp.Expression:
 META_FIELD_CONVERTER: t.Dict[str, t.Callable] = {
     "start": lambda value: exp.Literal.string(value),
     "cron": lambda value: exp.Literal.string(value),
-    "batch_size": lambda value: exp.Literal.number(value),
     "partitioned_by_": _single_expr_or_tuple,
     "clustered_by": _single_value_or_tuple,
     "depends_on_": lambda value: exp.Tuple(expressions=sorted(value)),

sqlmesh/core/model/kind.py

Lines changed: 2 additions & 0 deletions

@@ -254,6 +254,7 @@ def to_property(self, dialect: str = "") -> exp.Property:
 class _Incremental(_ModelKind):
     dialect: str = ""
     batch_size: t.Optional[SQLGlotPositiveInt] = None
+    batch_concurrency: t.Optional[SQLGlotPositiveInt] = None
     lookback: t.Optional[SQLGlotPositiveInt] = None
     forward_only: SQLGlotBool = False
     disable_restatement: SQLGlotBool = False
@@ -303,6 +304,7 @@ class IncrementalByUniqueKeyKind(_Incremental):
     name: Literal[ModelKindName.INCREMENTAL_BY_UNIQUE_KEY] = ModelKindName.INCREMENTAL_BY_UNIQUE_KEY
     unique_key: SQLGlotListOfFields
     when_matched: t.Optional[exp.When] = None
+    batch_concurrency: Literal[1] = 1
 
     @field_validator("when_matched", mode="before")
     @field_validator_v1_args
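The `batch_concurrency: Literal[1] = 1` annotation hard-pins the value for unique-key models: every batch merges into the same rows keyed by the unique key, so batches must never overlap. A rough sketch of the same pinning idea using plain dataclasses (SQLMesh itself relies on pydantic validation, so the class names and check below are illustrative only):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class IncrementalKind:
    # None means no model-level cap; the engine's task limit still applies.
    batch_concurrency: Optional[int] = None


@dataclass
class UniqueKeyKind(IncrementalKind):
    # Concurrent merges on the same unique key would race, so pin to 1.
    batch_concurrency: int = 1

    def __post_init__(self) -> None:
        if self.batch_concurrency != 1:
            raise ValueError("unique-key models must run batches serially")


print(UniqueKeyKind().batch_concurrency)  # 1
```

The effect matches the diff: other incremental kinds accept any positive value (or none), while the unique-key kind rejects anything but 1.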

sqlmesh/core/model/meta.py

Lines changed: 5 additions & 0 deletions

@@ -315,6 +315,11 @@ def batch_size(self) -> t.Optional[int]:
         """The maximal number of units in a single task for a backfill."""
         return getattr(self.kind, "batch_size", None)
 
+    @property
+    def batch_concurrency(self) -> t.Optional[int]:
+        """The maximal number of batches that can run concurrently for a backfill."""
+        return getattr(self.kind, "batch_concurrency", None)
+
     @cached_property
     def table_properties(self) -> t.Dict[str, exp.Expression]:
         """A dictionary of table properties."""

sqlmesh/core/node.py

Lines changed: 5 additions & 0 deletions

@@ -270,6 +270,11 @@ def batch_size(self) -> t.Optional[int]:
         """The maximal number of units in a single task for a backfill."""
         return None
 
+    @property
+    def batch_concurrency(self) -> t.Optional[int]:
+        """The maximal number of batches that can run concurrently for a backfill."""
+        return None
+
     @property
     def data_hash(self) -> str:
         """

sqlmesh/core/scheduler.py

Lines changed: 15 additions & 2 deletions

@@ -361,15 +361,28 @@ def _dag(self, batches: SnapshotToBatches) -> DAG[SchedulingUnit]:
                 for i, interval in enumerate(p_intervals):
                     upstream_dependencies.append((p_sid.name, (interval, i)))
 
+            batch_concurrency = snapshot.node.batch_concurrency
+            if snapshot.depends_on_past:
+                batch_concurrency = 1
+
             for i, interval in enumerate(intervals):
                 node = (snapshot.name, (interval, i))
                 dag.add(node, upstream_dependencies)
 
                 if len(intervals) > 1:
                     dag.add((snapshot.name, terminal_node), [node])
 
-                if snapshot.depends_on_past and i > 0:
-                    dag.add(node, [(snapshot.name, (intervals[i - 1], i - 1))])
+                if batch_concurrency and i >= batch_concurrency:
+                    batch_idx_to_wait_for = i - batch_concurrency
+                    dag.add(
+                        node,
+                        [
+                            (
+                                snapshot.name,
+                                (intervals[batch_idx_to_wait_for], batch_idx_to_wait_for),
+                            )
+                        ],
+                    )
         return dag
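The scheduler change above replaces the old `depends_on_past` rule (batch `i` waits for batch `i - 1`) with a sliding window: batch `i` waits for batch `i - batch_concurrency`, so at most `batch_concurrency` batches are runnable at once, and a window of 1 reproduces the old serial behavior. A standalone sketch of that edge rule, with a plain dict standing in for SQLMesh's `DAG` class:

```python
from typing import Dict, List, Optional


def batch_dependencies(
    num_batches: int, batch_concurrency: Optional[int]
) -> Dict[int, List[int]]:
    """Map each batch index to the earlier batch it must wait for."""
    deps: Dict[int, List[int]] = {}
    for i in range(num_batches):
        if batch_concurrency and i >= batch_concurrency:
            # Wait for the batch that just left the concurrency window.
            deps[i] = [i - batch_concurrency]
        else:
            # No cap, or still inside the initial window: start immediately.
            deps[i] = []
    return deps


print(batch_dependencies(6, 2))  # {0: [], 1: [], 2: [0], 3: [1], 4: [2], 5: [3]}
print(batch_dependencies(4, 1))  # {0: [], 1: [0], 2: [1], 3: [2]}: serial chain
```

Passing `None` yields no intra-model edges at all, matching the documented default where only the connection's task limit constrains concurrency.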
Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+"""Add the batch_concurrency attribute to the incremental model kinds.
+
+This results in a change to the metadata hash.
+"""
+
+
+def migrate(state_sync, **kwargs):  # type: ignore
+    pass

sqlmesh/schedulers/airflow/dag_generator.py

Lines changed: 18 additions & 20 deletions

@@ -6,7 +6,7 @@
 
 import pendulum
 from airflow import DAG
-from airflow.models import BaseOperator, baseoperator
+from airflow.models import BaseOperator
 from airflow.operators.python import PythonOperator
 from airflow.sensors.base import BaseSensorOperator
 
@@ -437,7 +437,7 @@ def _create_backfill_tasks(
         snapshot = snapshots[sid]
         sanitized_model_name = sanitize_name(snapshot.node.name)
 
-        snapshot_intervals_chain: t.List[t.Union[BaseOperator, t.List[BaseOperator]]] = []
+        snapshot_task_pairs: t.List[t.Tuple[BaseOperator, BaseOperator]] = []
 
         snapshot_start_task = EmptyOperator(
             task_id=f"snapshot_backfill__{sanitized_model_name}__{snapshot.identifier}__start"
@@ -457,32 +457,30 @@ def _create_backfill_tasks(
                 deployability_index=deployability_index,
                 plan_id=plan_id,
             )
-
             external_sensor_task = self._create_hwm_external_sensor(
                 snapshot, start=start, end=end
             )
             if external_sensor_task:
-                if snapshot.depends_on_past:
-                    snapshot_intervals_chain.extend([external_sensor_task, evaluation_task])
-                else:
-                    (
-                        snapshot_start_task
-                        >> external_sensor_task
-                        >> evaluation_task
-                        >> snapshot_end_task
-                    )
+                (
+                    snapshot_start_task
+                    >> external_sensor_task
+                    >> evaluation_task
+                    >> snapshot_end_task
+                )
+                snapshot_task_pairs.append((external_sensor_task, evaluation_task))
             else:
-                if snapshot.depends_on_past:
-                    snapshot_intervals_chain.append(evaluation_task)
-                else:
-                    snapshot_start_task >> evaluation_task >> snapshot_end_task
+                snapshot_start_task >> evaluation_task >> snapshot_end_task
+                snapshot_task_pairs.append((evaluation_task, evaluation_task))
 
+        batch_concurrency = snapshot.node.batch_concurrency
         if snapshot.depends_on_past:
-            baseoperator.chain(
-                snapshot_start_task, *snapshot_intervals_chain, snapshot_end_task
-            )
-        elif not intervals_per_snapshot.intervals:
+            batch_concurrency = 1
+
+        if not intervals_per_snapshot.intervals:
             snapshot_start_task >> snapshot_end_task
+        elif batch_concurrency:
+            for i in range(batch_concurrency, len(snapshot_task_pairs)):
+                snapshot_task_pairs[i - batch_concurrency][1] >> snapshot_task_pairs[i][0]
 
         snapshot_to_tasks[snapshot.snapshot_id] = (
            snapshot_start_task,
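In the Airflow generator, the same window is wired with `>>` edges between per-batch task pairs: each pair holds the first and last operator of a batch (the sensor and the evaluation task, or the evaluation task twice when there is no sensor), and the loop links the end of pair `i - batch_concurrency` to the start of pair `i`. A minimal sketch with stub operators (the `Task` class below is a stand-in, not a real Airflow class):

```python
class Task:
    """Stand-in for an Airflow operator; records downstream edges."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.downstream: list = []

    def __rshift__(self, other: "Task") -> "Task":
        # Mimic Airflow's `a >> b` dependency syntax.
        self.downstream.append(other.name)
        return other


def chain_batches(pairs: list, batch_concurrency: int) -> None:
    """Link the end of batch i - batch_concurrency to the start of batch i."""
    for i in range(batch_concurrency, len(pairs)):
        pairs[i - batch_concurrency][1] >> pairs[i][0]


# Three batches, each with a sensor + evaluation task, window of 1 (serial).
pairs = []
for i in range(3):
    sensor, evaluate = Task(f"sensor_{i}"), Task(f"eval_{i}")
    sensor >> evaluate
    pairs.append((sensor, evaluate))
chain_batches(pairs, 1)
print(pairs[0][1].downstream)  # ['sensor_1']: eval_0 gates batch 1
```

Using pairs rather than a flat chain is what lets the generator drop the old `baseoperator.chain` call: `depends_on_past` simply becomes a window of 1.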
