diff --git a/airflow-core/docs/authoring-and-scheduling/asset-scheduling.rst b/airflow-core/docs/authoring-and-scheduling/asset-scheduling.rst index 1bfe98d715ec2..a7a2fb6ed06c2 100644 --- a/airflow-core/docs/authoring-and-scheduling/asset-scheduling.rst +++ b/airflow-core/docs/authoring-and-scheduling/asset-scheduling.rst @@ -419,3 +419,29 @@ AssetTimetable Integration You can schedule Dags based on both asset events and time-based schedules using ``AssetOrTimeSchedule``. This allows you to create workflows when a Dag needs both to be triggered by data updates and run periodically according to a fixed timetable. For more detailed information on ``AssetOrTimeSchedule``, refer to the corresponding section in :ref:`AssetOrTimeSchedule <asset-timetable-section>`. + + +Controlling DagRun creation per asset event +--------------------------------------------- + +.. versionadded:: 3.3.0 + +By default, when multiple asset events arrive for the same Dag between +scheduler ticks, they are batched into a single DagRun. Set +``batch_asset_events=False`` on the timetable to create one DagRun per +individual event instead. + +.. code-block:: python + + from airflow.sdk import DAG, Asset + from airflow.timetables.simple import AssetTriggeredTimetable + + # Each update to "data-file" produces its own DagRun + with DAG( + dag_id="per-event-consumer", + schedule=AssetTriggeredTimetable( + assets=Asset("s3://bucket/data-file"), + batch_asset_events=False, + ), + ): + ... 
diff --git a/airflow-core/src/airflow/assets/manager.py b/airflow-core/src/airflow/assets/manager.py index f72c533c5a0e7..b227d8d7411a6 100644 --- a/airflow-core/src/airflow/assets/manager.py +++ b/airflow-core/src/airflow/assets/manager.py @@ -646,6 +646,7 @@ def _queue_partitioned_dags( rollup_fingerprint=fingerprint, asset_id=asset_id, session=session, + allow_reuse=timetable.batch_asset_events, ) log_record = PartitionedAssetKeyLog( asset_id=asset_id, @@ -666,6 +667,7 @@ def _get_or_create_apdr( rollup_fingerprint: dict, asset_id: int, session: Session, + allow_reuse: bool = True, ) -> AssetPartitionDagRun: """ Get or create an APDR. @@ -679,6 +681,12 @@ def _get_or_create_apdr( ``rollup_fingerprint`` is the serialized mapper / window definition for all partitioned assets in the timetable at creation time; the scheduler discards APDRs whose stamp no longer matches the current timetable's fingerprint (mapper / window may have changed). + + When ``allow_reuse=True`` (default), an existing pending APDR for the same + ``(target_dag, partition_key)`` is reused — multiple events accumulate on one + APDR. When ``allow_reuse=False`` (set when the timetable's ``batch_asset_events`` + is ``False``), a new APDR is always created so each event gets its own APDR + and the scheduler produces one DagRun per event. 
""" with _lock_asset_model(session=session, asset_id=asset_id): latest_apdr: AssetPartitionDagRun | None = session.scalar( @@ -690,7 +698,7 @@ def _get_or_create_apdr( .order_by(AssetPartitionDagRun.id.desc()) .limit(1) ) - if latest_apdr and latest_apdr.created_dag_run_id is None: + if latest_apdr and latest_apdr.created_dag_run_id is None and allow_reuse: cls.logger().debug( "Existing APDR found for key %s dag_id %s", target_key, diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 3e9e2321bc9c3..cc87d40ee53a6 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -2572,27 +2572,34 @@ def _create_dag_runs_asset_triggered( ) ) - dag_run = dag.create_dagrun( - run_id=DagRun.generate_run_id( - run_type=DagRunType.ASSET_TRIGGERED, logical_date=None, run_after=triggered_date - ), - logical_date=None, - data_interval=None, - run_after=triggered_date, - run_type=DagRunType.ASSET_TRIGGERED, - triggered_by=DagRunTriggeredByType.ASSET, - state=DagRunState.QUEUED, - creating_job_id=self.job.id, - session=session, - ) - stats.incr("asset.triggered_dagruns") - dag_run.consumed_asset_events.extend(asset_events) - self.log.info( - "Created asset-triggered DagRun for '%s': run_id=%s, consumed %d asset events", - dag.dag_id, - dag_run.run_id, - len(asset_events), - ) + # Build the list of (run_after, events) to process: one entry per DagRun to create. 
+ if dag.timetable.batch_asset_events: + event_runs = [(triggered_date, asset_events)] + else: + event_runs = [(timezone.coerce_datetime(ev.timestamp), [ev]) for ev in asset_events] + + for run_after, events in event_runs: + dag_run = dag.create_dagrun( + run_id=DagRun.generate_run_id( + run_type=DagRunType.ASSET_TRIGGERED, logical_date=None, run_after=run_after + ), + logical_date=None, + data_interval=None, + run_after=run_after, + run_type=DagRunType.ASSET_TRIGGERED, + triggered_by=DagRunTriggeredByType.ASSET, + state=DagRunState.QUEUED, + creating_job_id=self.job.id, + session=session, + ) + stats.incr("asset.triggered_dagruns") + dag_run.consumed_asset_events.extend(events) + self.log.info( + "Created asset-triggered DagRun for '%s': run_id=%s, consumed %d asset events", + dag.dag_id, + dag_run.run_id, + len(events), + ) # Delete only consumed ADRQ rows to avoid dropping newly queued events # (e.g. DagRun triggered by asset A while a new event for asset B arrives). diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index 59ac29f111ec7..7d65fec12463d 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -363,7 +363,10 @@ def _( @serialize_timetable.register def _(self, timetable: AssetTriggeredTimetable) -> dict[str, Any]: - return {"asset_condition": encode_asset_like(timetable.asset_condition)} + return { + "asset_condition": encode_asset_like(timetable.asset_condition), + "batch_asset_events": timetable.batch_asset_events, + } @serialize_timetable.register def _(self, timetable: EventsTimetable) -> dict[str, Any]: @@ -434,6 +437,7 @@ def _(self, timetable: CoreTimetable) -> dict[str, Any]: def _(self, timetable: PartitionedAssetTimetable) -> dict[str, Any]: return { "asset_condition": encode_asset_like(timetable.asset_condition), + "batch_asset_events": timetable.batch_asset_events, "default_partition_mapper": 
encode_partition_mapper(timetable.default_partition_mapper), "partition_mapper_config": [ (encode_asset_like(asset), encode_partition_mapper(partition_mapper)) diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py index 496ba9da156ed..e268bc0fbb3d1 100644 --- a/airflow-core/src/airflow/timetables/simple.py +++ b/airflow-core/src/airflow/timetables/simple.py @@ -216,8 +216,11 @@ class AssetTriggeredTimetable(_TrivialTimetable): description: str = "Triggered by assets" - def __init__(self, assets: Collection[SerializedAsset] | SerializedAssetBase) -> None: + def __init__( + self, assets: Collection[SerializedAsset] | SerializedAssetBase, *, batch_asset_events: bool = True + ) -> None: super().__init__() + self.batch_asset_events = batch_asset_events # Compatibility: Handle SDK assets if needed so this class works in dag files. if isinstance(assets, SerializedAssetBase | BaseAsset): self.asset_condition = ensure_serialized_asset(assets) @@ -228,14 +231,20 @@ def __init__(self, assets: Collection[SerializedAsset] | SerializedAssetBase) -> def deserialize(cls, data: dict[str, Any]) -> Timetable: from airflow.serialization.decoders import decode_asset_like - return cls(decode_asset_like(data["asset_condition"])) + return cls( + decode_asset_like(data["asset_condition"]), + batch_asset_events=data.get("batch_asset_events", True), + ) @property def summary(self) -> str: return "Asset" def serialize(self) -> dict[str, Any]: - return {"asset_condition": encode_asset_like(self.asset_condition)} + return { + "asset_condition": encode_asset_like(self.asset_condition), + "batch_asset_events": self.batch_asset_events, + } def generate_run_id( self, @@ -283,10 +292,11 @@ def __init__( self, *, assets: SerializedAssetBase, + batch_asset_events: bool = True, partition_mapper_config: dict[SerializedAssetBase, PartitionMapper] | None = None, default_partition_mapper: PartitionMapper = DEFAULT_PARTITION_MAPPER, ) -> None: - 
super().__init__(assets=assets) + super().__init__(assets=assets, batch_asset_events=batch_asset_events) self.partition_mapper_config = partition_mapper_config or {} self.default_partition_mapper = default_partition_mapper @@ -360,6 +370,7 @@ def serialize(self) -> dict[str, Any]: return { "asset_condition": encode_asset_like(self.asset_condition), + "batch_asset_events": self.batch_asset_events, "partition_mapper_config": [ (encode_asset_like(asset), encode_partition_mapper(partition_mapper)) for asset, partition_mapper in self.partition_mapper_config.items() @@ -377,6 +388,7 @@ def deserialize(cls, data: dict[str, Any]) -> PartitionedAssetTimetable: timetable = cls( assets=decode_asset_like(data["asset_condition"]), + batch_asset_events=data.get("batch_asset_events", True), default_partition_mapper=decode_partition_mapper(default_partition_mapper_data), partition_mapper_config={ decode_asset_like(ser_asest): decode_partition_mapper(ser_partition_mapper) diff --git a/airflow-core/tests/unit/assets/test_manager.py b/airflow-core/tests/unit/assets/test_manager.py index 687f4d178479d..1a969da7b04f9 100644 --- a/airflow-core/tests/unit/assets/test_manager.py +++ b/airflow-core/tests/unit/assets/test_manager.py @@ -275,6 +275,86 @@ def _get_or_create_apdr(): assert len(set(ids)) == 1 assert session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 1 + @pytest.mark.usefixtures("testing_dag_bundle") + def test_get_or_create_apdr_allow_reuse_true_reuses_pending(self, session): + """``allow_reuse=True`` (default) reuses a pending APDR for the same ``(target_dag, partition_key)``. + + When two events arrive for the same key and ``allow_reuse=True``, the + second call returns the same APDR — they accumulate on one row. 
+ """ + clear_db_apdr() + clear_db_pakl() + asm = AssetModel(uri="test://reuse-true/", name="reuse_asset_true", group="asset") + testing_dag = DagModel(dag_id="reuse_test_dag_true", is_stale=False, bundle_name="testing") + session.add_all([asm, testing_dag]) + session.commit() + session.flush() + assert session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 0 + + rollup_fingerprint = {} + + apdr_1 = AssetManager._get_or_create_apdr( + target_key="key-1", + target_dag=testing_dag, + rollup_fingerprint=rollup_fingerprint, + asset_id=asm.id, + session=session, + allow_reuse=True, + ) + apdr_2 = AssetManager._get_or_create_apdr( + target_key="key-1", + target_dag=testing_dag, + rollup_fingerprint=rollup_fingerprint, + asset_id=asm.id, + session=session, + allow_reuse=True, + ) + + assert apdr_1.id == apdr_2.id + assert session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 1 + assert apdr_1.created_dag_run_id is None # still pending + + @pytest.mark.usefixtures("testing_dag_bundle") + def test_get_or_create_apdr_allow_reuse_false_creates_new(self, session): + """``allow_reuse=False`` creates a new APDR each call even if a pending one exists for the same key. + + When two events arrive for the same key and ``allow_reuse=False``, each + event gets its own APDR — the scheduler later produces one DagRun per + event. 
+ """ + clear_db_apdr() + clear_db_pakl() + asm = AssetModel(uri="test://reuse-false/", name="reuse_asset_false", group="asset") + testing_dag = DagModel(dag_id="reuse_test_dag_false", is_stale=False, bundle_name="testing") + session.add_all([asm, testing_dag]) + session.commit() + session.flush() + assert session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 0 + + rollup_fingerprint = {} + + apdr_1 = AssetManager._get_or_create_apdr( + target_key="key-1", + target_dag=testing_dag, + rollup_fingerprint=rollup_fingerprint, + asset_id=asm.id, + session=session, + allow_reuse=False, + ) + apdr_2 = AssetManager._get_or_create_apdr( + target_key="key-1", + target_dag=testing_dag, + rollup_fingerprint=rollup_fingerprint, + asset_id=asm.id, + session=session, + allow_reuse=False, + ) + + assert apdr_1.id != apdr_2.id + assert session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 2 + assert apdr_1.created_dag_run_id is None + assert apdr_2.created_dag_run_id is None + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("testing_dag_bundle") def test_queue_partitioned_dags_stamps_rollup_fingerprint(self, session, dag_maker): diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 67986c0297ee4..3040115399301 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -132,6 +132,7 @@ from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.timetables.base import DagRunInfo, DataInterval, compute_rollup_fingerprint from airflow.timetables.simple import ( + AssetTriggeredTimetable, PartitionAtRuntime, PartitionedAssetTimetable as CorePartitionedAssetTimetable, ) @@ -10033,6 +10034,286 @@ def _produce_and_register_asset_event( return apdr +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def 
test_partitioned_batch_asset_events_true_single_dagrun(dag_maker: DagMaker, session: Session): + """batch_asset_events=True (default): APDR reuse produces one DagRun for all events. + + Two events for the same partition key share one APDR. The scheduler + creates a single DagRun consuming both events. + """ + asset_1 = Asset(name="asset-batch-true") + + # Consumer Dag with batch_asset_events=True (default). + with dag_maker( + dag_id="batch-true-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=IdentityMapper(), + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + + # Two events, same partition key → same APDR (reuse). + apdr = _produce_and_register_asset_event( + dag_id="batch-true-producer-1", + asset=asset_1, + partition_key="key-1", + session=session, + dag_maker=dag_maker, + ) + _produce_and_register_asset_event( + dag_id="batch-true-producer-2", + asset=asset_1, + partition_key="key-1", + session=session, + dag_maker=dag_maker, + ) + + assert session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 1 + + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + assert len(partition_dags) == 1 + assert partition_dags == {"batch-true-consumer"} + + session.refresh(apdr) + assert apdr.created_dag_run_id is not None + dag_run = session.scalar(select(DagRun).where(DagRun.id == apdr.created_dag_run_id)) + assert dag_run is not None + assert len(dag_run.consumed_asset_events) == 2 + + +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_partitioned_batch_asset_events_false_one_dagrun_per_event(dag_maker: DagMaker, session: Session): + """batch_asset_events=False: each event gets its own APDR → one DagRun per event. 
+ + Two events for the same partition key produce two APDRs (no reuse). + The scheduler creates two DagRuns, one per event. + """ + asset_1 = Asset(name="asset-batch-false") + + # Consumer Dag with batch_asset_events=False. + with dag_maker( + dag_id="batch-false-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=IdentityMapper(), + batch_asset_events=False, + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + + # Two events, same partition key → two APDRs (no reuse because batch_asset_events=False). + apdr_1 = _produce_and_register_asset_event( + dag_id="batch-false-producer-1", + asset=asset_1, + partition_key="key-1", + session=session, + dag_maker=dag_maker, + ) + apdr_2 = _produce_and_register_asset_event( + dag_id="batch-false-producer-2", + asset=asset_1, + partition_key="key-1", + session=session, + dag_maker=dag_maker, + ) + + assert apdr_1.id != apdr_2.id + assert session.scalar(select(func.count()).select_from(AssetPartitionDagRun)) == 2 + + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + assert len(partition_dags) == 1 + assert partition_dags == {"batch-false-consumer"} + + # Both APDRs should now have a DagRun. 
+ session.refresh(apdr_1) + session.refresh(apdr_2) + assert apdr_1.created_dag_run_id is not None + assert apdr_2.created_dag_run_id is not None + assert apdr_1.created_dag_run_id != apdr_2.created_dag_run_id + + dag_run_1 = session.scalar(select(DagRun).where(DagRun.id == apdr_1.created_dag_run_id)) + dag_run_2 = session.scalar(select(DagRun).where(DagRun.id == apdr_2.created_dag_run_id)) + assert dag_run_1 is not None + assert dag_run_2 is not None + assert dag_run_1.run_id != dag_run_2.run_id + assert len(dag_run_1.consumed_asset_events) == 1 + assert len(dag_run_2.consumed_asset_events) == 1 + + +@pytest.mark.need_serialized_dag +def test_non_partitioned_batch_asset_events_true_single_dagrun( + dag_maker: DagMaker, + session: Session, +): + """``batch_asset_events=True`` in non-partitioned path: one DagRun for all events. + + Multiple asset events for the same asset and Dag produce a single DagRun + that consumes all events. + """ + asset_1 = Asset(name="non-part-batch-true") + + # Consumer Dag with default AssetTriggeredTimetable (batch_asset_events=True). + with dag_maker( + dag_id="non-part-batch-true-consumer", + schedule=[asset_1], + session=session, + ): + EmptyOperator(task_id="task") + session.commit() + + dag_model = session.scalar(select(DagModel).where(DagModel.dag_id == "non-part-batch-true-consumer")) + assert dag_model is not None + asset_model = session.scalar(select(AssetModel).where(AssetModel.uri == asset_1.uri)) + assert asset_model is not None + + # Create two asset events with timestamps clearly before the ADRQ's created_at. 
+ now = timezone.utcnow() + event_1 = AssetEvent( + asset_id=asset_model.id, + source_task_id="task", + source_dag_id="non-part-batch-true-consumer", + source_run_id="test-run", + source_map_index=-1, + timestamp=now - timedelta(minutes=5), + ) + event_2 = AssetEvent( + asset_id=asset_model.id, + source_task_id="task", + source_dag_id="non-part-batch-true-consumer", + source_run_id="test-run", + source_map_index=-1, + timestamp=now - timedelta(minutes=4), + ) + session.add_all([event_1, event_2]) + session.flush() + + # Queue an ADRQ for this Dag so the scheduler picks it up. + session.add(AssetDagRunQueue(asset_id=asset_model.id, target_dag_id="non-part-batch-true-consumer")) + session.flush() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + runner._create_dag_runs_asset_triggered( + dag_models=[dag_model], + session=session, + ) + + dag_runs = session.scalars(select(DagRun).where(DagRun.dag_id == "non-part-batch-true-consumer")).all() + assert len(dag_runs) == 1 + dag_run = dag_runs[0] + assert dag_run.run_type == DagRunType.ASSET_TRIGGERED + assert dag_run.state == DagRunState.QUEUED + assert len(dag_run.consumed_asset_events) == 2 + + # The ADRQ should have been cleaned up. + assert ( + session.scalar( + select(func.count()) + .select_from(AssetDagRunQueue) + .where(AssetDagRunQueue.target_dag_id == "non-part-batch-true-consumer") + ) + == 0 + ) + + +@pytest.mark.need_serialized_dag +def test_non_partitioned_batch_asset_events_false_one_dagrun_per_event( + dag_maker: DagMaker, + session: Session, +): + """``batch_asset_events=False`` in non-partitioned path: one DagRun per event. + + Multiple asset events for the same asset and Dag each produce their own + DagRun, each consuming exactly one event. + """ + asset_1 = Asset(name="non-part-batch-false") + + # Consumer Dag with batch_asset_events=False on the timetable. 
+ with dag_maker( + dag_id="non-part-batch-false-consumer", + schedule=AssetTriggeredTimetable( + assets=asset_1, # type: ignore[arg-type] + batch_asset_events=False, + ), + session=session, + ): + EmptyOperator(task_id="task") + session.commit() + + dag_model = session.scalar(select(DagModel).where(DagModel.dag_id == "non-part-batch-false-consumer")) + assert dag_model is not None + asset_model = session.scalar(select(AssetModel).where(AssetModel.uri == asset_1.uri)) + assert asset_model is not None + + now = timezone.utcnow() + event_1 = AssetEvent( + asset_id=asset_model.id, + source_task_id="task", + source_dag_id="non-part-batch-false-consumer", + source_run_id="test-run", + source_map_index=-1, + timestamp=now - timedelta(minutes=5), + ) + event_2 = AssetEvent( + asset_id=asset_model.id, + source_task_id="task", + source_dag_id="non-part-batch-false-consumer", + source_run_id="test-run", + source_map_index=-1, + timestamp=now - timedelta(minutes=4), + ) + session.add_all([event_1, event_2]) + session.flush() + + session.add(AssetDagRunQueue(asset_id=asset_model.id, target_dag_id="non-part-batch-false-consumer")) + session.flush() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + runner._create_dag_runs_asset_triggered( + dag_models=[dag_model], + session=session, + ) + + dag_runs = session.scalars( + select(DagRun).where(DagRun.dag_id == "non-part-batch-false-consumer").order_by(DagRun.id) + ).all() + assert len(dag_runs) == 2 + for dag_run in dag_runs: + assert dag_run.run_type == DagRunType.ASSET_TRIGGERED + assert dag_run.state == DagRunState.QUEUED + assert len(dag_runs[0].consumed_asset_events) == 1 + assert len(dag_runs[1].consumed_asset_events) == 1 + assert dag_runs[0].run_id != dag_runs[1].run_id + + # ADRQ cleaned up. 
+ assert ( + session.scalar( + select(func.count()) + .select_from(AssetDagRunQueue) + .where(AssetDagRunQueue.target_dag_id == "non-part-batch-false-consumer") + ) + == 0 + ) + + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("clear_asset_partition_rows") def test_partitioned_dag_run_with_invalid_mapping(dag_maker: DagMaker, session: Session): diff --git a/airflow-core/tests/unit/timetables/test_assets_timetable.py b/airflow-core/tests/unit/timetables/test_assets_timetable.py index 3c12c886f283f..a5d891cc0f6b1 100644 --- a/airflow-core/tests/unit/timetables/test_assets_timetable.py +++ b/airflow-core/tests/unit/timetables/test_assets_timetable.py @@ -268,6 +268,44 @@ def test_run_ordering_inheritance(core_asset_timetable) -> None: assert core_asset_timetable.run_ordering == AssetTriggeredTimetable.run_ordering +def test_asset_triggered_timetable_serialize(): + """AssetTriggeredTimetable.serialize includes batch_asset_events.""" + asset = SerializedAsset(name="test", uri="test://uri", group="asset", extra={}, watchers=[]) + timetable = AssetTriggeredTimetable(assets=asset) + serialized = timetable.serialize() + assert serialized["batch_asset_events"] is True + assert "asset_condition" in serialized + + +def test_asset_triggered_timetable_deserialize(): + """AssetTriggeredTimetable.deserialize recovers batch_asset_events.""" + asset = SerializedAsset(name="test", uri="test://uri", group="asset", extra={}, watchers=[]) + data = { + "asset_condition": { + "__type": "asset", + "name": "test", + "uri": "test://uri", + "group": "asset", + "extra": {}, + }, + "batch_asset_events": True, + } + timetable = AssetTriggeredTimetable.deserialize(data) + assert timetable.batch_asset_events is True + assert timetable.asset_condition == asset + + +def test_asset_triggered_timetable_batch_asset_events_false_roundtrip(): + """AssetTriggeredTimetable batch_asset_events=False survives serialize → deserialize.""" + asset = SerializedAsset(name="test", uri="test://uri", 
group="asset", extra={}, watchers=[]) + timetable = AssetTriggeredTimetable(assets=asset, batch_asset_events=False) + serialized = timetable.serialize() + assert serialized["batch_asset_events"] is False + + deserialized = AssetTriggeredTimetable.deserialize(serialized) + assert deserialized.batch_asset_events is False + + @pytest.mark.db_test class TestAssetConditionWithTimetable: @pytest.fixture(autouse=True) @@ -341,6 +379,7 @@ def test_dag_with_complex_asset_condition(self, dag_maker): serialized_timetable_dict = DagSerialization.to_dict(dag)["dag"]["timetable"]["__var"] assert serialized_timetable_dict == { + "batch_asset_events": True, "asset_condition": { "__type": "asset_any", "objects": [ diff --git a/airflow-core/tests/unit/timetables/test_partitioned_timetable.py b/airflow-core/tests/unit/timetables/test_partitioned_timetable.py index 23caf06e79d0e..f6bb80674dd46 100644 --- a/airflow-core/tests/unit/timetables/test_partitioned_timetable.py +++ b/airflow-core/tests/unit/timetables/test_partitioned_timetable.py @@ -194,6 +194,7 @@ def test_serialize(self): assets=ser_asset, partition_mapper_config={ser_asset: IdentityMapper()} ) assert timetable.serialize() == { + "batch_asset_events": True, "asset_condition": { "__type": DagAttributeTypes.ASSET, "name": "test", @@ -225,6 +226,7 @@ def test_serialize(self): def test_deserialize(self): timetable = PartitionedAssetTimetable.deserialize( { + "batch_asset_events": True, "asset_condition": { "__type": DagAttributeTypes.ASSET, "name": "test", @@ -257,6 +259,21 @@ def test_deserialize(self): assert timetable.asset_condition == ser_asset assert isinstance(timetable.default_partition_mapper, IdentityMapper) assert isinstance(timetable.partition_mapper_config[ser_asset], IdentityMapper) + assert timetable.batch_asset_events is True + + def test_serialize_deserialize_batch_asset_events_false(self): + """Serialize/deserialize round-trip preserves batch_asset_events=False.""" + ser_asset = 
ensure_serialized_asset(Asset("test")) + timetable = PartitionedAssetTimetable( + assets=ser_asset, + partition_mapper_config={ser_asset: IdentityMapper()}, + batch_asset_events=False, + ) + serialized = timetable.serialize() + assert serialized["batch_asset_events"] is False + + deserialized = PartitionedAssetTimetable.deserialize(serialized) + assert deserialized.batch_asset_events is False def test_partitioned_asset_timetable_resolve_day_bound_returns_midnight_utc(self): """PartitionedAssetTimetable has no local timezone; resolve_day_bound uses the base default. diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py b/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py index 4bfc5090230fe..ea719a3cb4b78 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_utils.py @@ -50,7 +50,7 @@ from tests_common.test_utils.compat import ( BashOperator, ) -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS if AIRFLOW_V_3_1_PLUS: from airflow.models.dag import get_next_data_interval @@ -374,7 +374,7 @@ def test_serialize_timetable_complex_with_alias(): dag.timetable = AssetTriggeredTimetable(asset) dag_info = DagInfo(dag) - assert dag_info.timetable == { + expected: dict = { "asset_condition": { "__type": DagAttributeTypes.ASSET_ANY, "objects": [ @@ -415,38 +415,47 @@ def test_serialize_timetable_complex_with_alias(): ], }, ], - } + }, } + if AIRFLOW_V_3_3_PLUS: + expected["batch_asset_events"] = True + assert dag_info.timetable == expected @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This test checks serialization only in 3.0 conditions") def test_serialize_timetable_single_asset(): dag = DAG(dag_id="test", start_date=datetime.datetime(2025, 1, 1), schedule=Asset("a")) dag_info = DagInfo(dag) 
- assert dag_info.timetable == { + expected: dict = { "asset_condition": { "__type": DagAttributeTypes.ASSET, "uri": "a", "name": "a", "group": "asset", "extra": {}, - } + }, } + if AIRFLOW_V_3_3_PLUS: + expected["batch_asset_events"] = True + assert dag_info.timetable == expected @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This test checks serialization only in 3.0 conditions") def test_serialize_timetable_list_of_assets(): dag = DAG(dag_id="test", start_date=datetime.datetime(2025, 1, 1), schedule=[Asset("a"), Asset("b")]) dag_info = DagInfo(dag) - assert dag_info.timetable == { + expected: dict = { "asset_condition": { "__type": DagAttributeTypes.ASSET_ALL, "objects": [ {"__type": DagAttributeTypes.ASSET, "uri": "a", "name": "a", "group": "asset", "extra": {}}, {"__type": DagAttributeTypes.ASSET, "uri": "b", "name": "b", "group": "asset", "extra": {}}, ], - } + }, } + if AIRFLOW_V_3_3_PLUS: + expected["batch_asset_events"] = True + assert dag_info.timetable == expected @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This test checks serialization only in 3.0 conditions") @@ -458,7 +467,7 @@ def test_serialize_timetable_with_complex_logical_condition(): & (Asset("ds3") | Asset("ds4", extra={"another_extra": 345})), ) dag_info = DagInfo(dag) - assert dag_info.timetable == { + expected: dict = { "asset_condition": { "__type": DagAttributeTypes.ASSET_ALL, "objects": [ @@ -501,8 +510,11 @@ def test_serialize_timetable_with_complex_logical_condition(): ], }, ], - } + }, } + if AIRFLOW_V_3_3_PLUS: + expected["batch_asset_events"] = True + assert dag_info.timetable == expected @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="This test checks serialization only in 3.0 conditions") diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index 36fe5125a7f0f..85d3bc3e9c0c3 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ 
b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -92,6 +92,7 @@ AIRFLOW_V_3_0_3_PLUS, AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS, + AIRFLOW_V_3_3_PLUS, ) BASH_OPERATOR_PATH = "airflow.providers.standard.operators.bash" @@ -2469,7 +2470,7 @@ def test_dag_info_schedule_single_asset_directly(self): ) result = DagInfo(dag) - assert dict(result) == { + expected: dict = { "dag_id": "dag_id", "description": None, "fileloc": pathlib.Path(__file__).resolve().as_posix(), @@ -2484,10 +2485,13 @@ def test_dag_info_schedule_single_asset_directly(self): "name": "uri1", "group": "asset", "extra": {"a": 1}, - } + }, }, "timetable_summary": "Asset", } + if AIRFLOW_V_3_3_PLUS: + expected["timetable"]["batch_asset_events"] = True + assert dict(result) == expected def test_dag_info_schedule_list_single_assets(self): dag = DAG( @@ -2497,7 +2501,7 @@ def test_dag_info_schedule_list_single_assets(self): ) result = DagInfo(dag) - assert dict(result) == { + expected: dict = { "dag_id": "dag_id", "description": None, "fileloc": pathlib.Path(__file__).resolve().as_posix(), @@ -2517,10 +2521,13 @@ def test_dag_info_schedule_list_single_assets(self): "extra": {"a": 1}, } ], - } + }, }, "timetable_summary": "Asset", } + if AIRFLOW_V_3_3_PLUS: + expected["timetable"]["batch_asset_events"] = True + assert dict(result) == expected def test_dag_info_schedule_list_two_assets(self): dag = DAG( @@ -2530,7 +2537,7 @@ def test_dag_info_schedule_list_two_assets(self): ) result = DagInfo(dag) - assert dict(result) == { + expected: dict = { "dag_id": "dag_id", "description": None, "fileloc": pathlib.Path(__file__).resolve().as_posix(), @@ -2551,10 +2558,13 @@ def test_dag_info_schedule_list_two_assets(self): }, {"__type": "asset", "uri": "uri2", "name": "uri2", "group": "asset", "extra": {}}, ], - } + }, }, "timetable_summary": "Asset", } + if AIRFLOW_V_3_3_PLUS: + expected["timetable"]["batch_asset_events"] = True + assert dict(result) == expected def 
test_dag_info_schedule_assets_logical_condition(self): dag = DAG( @@ -2564,7 +2574,7 @@ def test_dag_info_schedule_assets_logical_condition(self): ) result = DagInfo(dag) - assert dict(result) == { + expected: dict = { "dag_id": "dag_id", "description": None, "fileloc": pathlib.Path(__file__).resolve().as_posix(), @@ -2615,10 +2625,13 @@ def test_dag_info_schedule_assets_logical_condition(self): ], }, ], - } + }, }, "timetable_summary": "Asset", } + if AIRFLOW_V_3_3_PLUS: + expected["timetable"]["batch_asset_events"] = True + assert dict(result) == expected def test_dag_info_schedule_asset_or_time_schedule(self): from airflow.timetables.assets import AssetOrTimeSchedule diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py index a0c6493692572..ec7170e9c4846 100644 --- a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py +++ b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py @@ -43,13 +43,13 @@ class AssetTriggeredTimetable(BaseTimetable): """ asset_condition: BaseAsset = attrs.field(alias="assets") + batch_asset_events: bool = True @attrs.define class PartitionedAssetTimetable(AssetTriggeredTimetable): """Asset-driven timetable that listens for partitioned assets.""" - asset_condition: BaseAsset = attrs.field(alias="assets") partition_mapper_config: dict[BaseAsset, PartitionMapper] = attrs.field(factory=dict) default_partition_mapper: PartitionMapper = IdentityMapper()