diff --git a/ingestify/application/dataset_store.py b/ingestify/application/dataset_store.py index 738cfb5..8aa0156 100644 --- a/ingestify/application/dataset_store.py +++ b/ingestify/application/dataset_store.py @@ -497,15 +497,19 @@ def update_dataset( def invalidate_revision(self, dataset: Dataset, reason: str = ""): """Mark the current revision as VALIDATION_FAILED and reset - last_modified_at so the dataset is refetched on the next run. - - Args: - dataset: Dataset whose current revision should be invalidated - reason: Human-readable reason for invalidation - """ - self.dataset_repository.invalidate_revision(dataset) - - self.dispatch(RevisionInvalidated(dataset=dataset, reason=reason)) + last_modified_at so the dataset is refetched on the next run.""" + self.invalidate_revisions([dataset], reason=reason) + + def invalidate_revisions(self, datasets: list, reason: str = ""): + """Batch invalidate revisions. Batches DB updates and event writes + per 1000 datasets for efficiency.""" + batch_size = 1000 + for i in range(0, len(datasets), batch_size): + batch = datasets[i : i + batch_size] + self.dataset_repository.invalidate_revisions(batch) + self.event_bus.dispatch_many( + [RevisionInvalidated(dataset=ds, reason=reason) for ds in batch] + ) def destroy_dataset(self, dataset: Dataset): # TODO: remove files. Now we leave some orphaned files around diff --git a/ingestify/domain/models/dataset/dataset_repository.py b/ingestify/domain/models/dataset/dataset_repository.py index 2079812..eb39cdd 100644 --- a/ingestify/domain/models/dataset/dataset_repository.py +++ b/ingestify/domain/models/dataset/dataset_repository.py @@ -36,10 +36,15 @@ def get_dataset_last_modified_at_map( dataset+revision+file graph.""" return {} - @abstractmethod def invalidate_revision(self, dataset: Dataset): """Mark the current revision as VALIDATION_FAILED and reset last_modified_at on the dataset.""" + self.invalidate_revisions([dataset]) + + @abstractmethod + def invalidate_revisions(self, datasets: list[Dataset]): + """Batch invalidate: mark current revisions as VALIDATION_FAILED + and reset last_modified_at on the datasets.""" pass @abstractmethod diff --git a/ingestify/domain/models/event/dispatcher.py b/ingestify/domain/models/event/dispatcher.py index 3cabcac..ae8bc8d 100644 --- a/ingestify/domain/models/event/dispatcher.py +++ b/ingestify/domain/models/event/dispatcher.py @@ -6,3 +6,6 @@ class Dispatcher(Protocol): def dispatch(self, event: DomainEvent): pass + + def dispatch_many(self, events: list[DomainEvent]): + pass diff --git a/ingestify/domain/models/event/event_bus.py b/ingestify/domain/models/event/event_bus.py index cc345e0..6604c70 100644 --- a/ingestify/domain/models/event/event_bus.py +++ b/ingestify/domain/models/event/event_bus.py @@ -14,6 +14,10 @@ def __init__(self, queue): def dispatch(self, event): self.queue.put(event) + def dispatch_many(self, events): + for event in events: + self.queue.put(event) + class EventBus: def __init__(self): @@ -37,3 +41,11 @@ def dispatch(self, event): except Exception as e: logger.exception(f"Failed to handle {event}") raise Exception(f"Failed to handle {event}") from e + + def dispatch_many(self, events): + for dispatcher in self.dispatchers: + try: + dispatcher.dispatch_many(events) + except Exception as e: + logger.exception(f"Failed to handle {len(events)} events") + raise Exception(f"Failed to handle {len(events)} events") from e diff --git a/ingestify/domain/models/event/publisher.py b/ingestify/domain/models/event/publisher.py index 557634e..7e075fa 100644 --- a/ingestify/domain/models/event/publisher.py +++ b/ingestify/domain/models/event/publisher.py @@ -19,5 +19,14 @@ def dispatch(self, event: DomainEvent): except Exception: logger.exception(f"Failed to handle {event} by {subscriber}") + def dispatch_many(self, events: list[DomainEvent]): + for subscriber in self.subscribers: + try: + subscriber.handle_many(events) + except Exception: + logger.exception( + f"Failed to handle {len(events)} events by {subscriber}" + ) + def add_subscriber(self, subscriber: Subscriber): self.subscribers.append(subscriber) diff --git a/ingestify/domain/models/event/subscriber.py b/ingestify/domain/models/event/subscriber.py index 5406015..70350b5 100644 --- a/ingestify/domain/models/event/subscriber.py +++ b/ingestify/domain/models/event/subscriber.py @@ -44,3 +44,8 @@ def handle(self, event: DomainEvent): self.on_revision_added(event) elif isinstance(event, RevisionInvalidated): self.on_revision_invalidated(event) + + def handle_many(self, events: list[DomainEvent]): + """Handle a batch of events. Override for efficient bulk writes.""" + for event in events: + self.handle(event) diff --git a/ingestify/infra/event_log/event_log.py b/ingestify/infra/event_log/event_log.py index a2be564..7cf7350 100644 --- a/ingestify/infra/event_log/event_log.py +++ b/ingestify/infra/event_log/event_log.py @@ -31,16 +31,24 @@ def __init__(self, engine, table_prefix: str = ""): self._table.create(engine, checkfirst=True) def write(self, event: DomainEvent) -> None: + self.write_many([event]) + + def write_many(self, events: list[DomainEvent]) -> None: + if not events: + return + now = utcnow() + rows = [ + { + "event_type": type(event).event_type, + "payload_json": event.model_dump(mode="json"), + "source": event.dataset.provider, + "dataset_id": event.dataset.dataset_id, + "created_at": now, + } + for event in events + ] with self._engine.connect() as conn: - conn.execute( - self._table.insert().values( - event_type=type(event).event_type, - payload_json=event.model_dump(mode="json"), - source=event.dataset.provider, - dataset_id=event.dataset.dataset_id, - created_at=utcnow(), - ) - ) + conn.execute(self._table.insert(), rows) conn.commit() def fetch_batch(self, last_event_id: int, batch_size: int) -> list: diff --git a/ingestify/infra/event_log/subscriber.py b/ingestify/infra/event_log/subscriber.py index eac38b5..7b371fb 100644 --- a/ingestify/infra/event_log/subscriber.py +++ b/ingestify/infra/event_log/subscriber.py @@ -34,6 +34,15 @@ def _write(self, event) -> None: event.dataset.dataset_id, ) + def _write_many(self, events) -> None: + try: + self._event_log.write_many(events) + except Exception: + logger.exception( + "EventLogSubscriber: failed to write %d events", + len(events), + ) + def on_dataset_created(self, event) -> None: self._write(event) @@ -45,3 +54,6 @@ def on_revision_added(self, event) -> None: def on_revision_invalidated(self, event) -> None: self._write(event) + + def handle_many(self, events) -> None: + self._write_many(events) diff --git a/ingestify/infra/store/dataset/sqlalchemy/repository.py b/ingestify/infra/store/dataset/sqlalchemy/repository.py index 4b9b14c..e58ddb0 100644 --- a/ingestify/infra/store/dataset/sqlalchemy/repository.py +++ b/ingestify/infra/store/dataset/sqlalchemy/repository.py @@ -677,28 +677,34 @@ def _save(self, datasets: list[Dataset]): connection.commit() def invalidate_revision(self, dataset: Dataset): - current_revision = dataset.current_revision + self.invalidate_revisions([dataset]) + + def invalidate_revisions(self, datasets: list[Dataset]): + if not datasets: + return + + dataset_ids = [d.dataset_id for d in datasets] + with self.connect() as connection: - # Set revision state to VALIDATION_FAILED + # Batch update revision state connection.execute( self.revision_table.update() - .where(self.revision_table.c.dataset_id == dataset.dataset_id) - .where( - self.revision_table.c.revision_id == current_revision.revision_id - ) + .where(self.revision_table.c.dataset_id.in_(dataset_ids)) .values(state=RevisionState.VALIDATION_FAILED) ) - # Reset last_modified_at so the pre-check cache doesn't skip it + # Batch reset last_modified_at connection.execute( self.dataset_table.update() - .where(self.dataset_table.c.dataset_id == dataset.dataset_id) + .where(self.dataset_table.c.dataset_id.in_(dataset_ids)) .values(last_modified_at=None) ) connection.commit() # Update in-memory state - current_revision.state = RevisionState.VALIDATION_FAILED - dataset.last_modified_at = None + for dataset in datasets: + if dataset.current_revision: + dataset.current_revision.state = RevisionState.VALIDATION_FAILED + dataset.last_modified_at = None def destroy(self, dataset: Dataset): with self.connect() as connection: diff --git a/ingestify/tests/test_refetch_validation_failed.py b/ingestify/tests/test_refetch_validation_failed.py index cb6eaf4..a962d01 100644 --- a/ingestify/tests/test_refetch_validation_failed.py +++ b/ingestify/tests/test_refetch_validation_failed.py @@ -25,22 +25,27 @@ def counting_loader(file_resource, current_file, **kwargs): class SimpleSource(Source): provider = "test_provider" + def __init__(self, name, n_datasets=1): + super().__init__(name) + self.n_datasets = n_datasets + def find_datasets( self, dataset_type, data_spec_versions, dataset_collection_metadata, **kwargs ): - r = DatasetResource( - dataset_resource_id={"item_id": 1}, - provider=self.provider, - dataset_type="test", - name="item-1", - ) - r.add_file( - last_modified=FIXED_TIME, - data_feed_key="f1", - data_spec_version="v1", - file_loader=counting_loader, - ) - yield r + for i in range(self.n_datasets): + r = DatasetResource( + dataset_resource_id={"item_id": i}, + provider=self.provider, + dataset_type="test", + name=f"item-{i}", + ) + r.add_file( + last_modified=FIXED_TIME, + data_feed_key="f1", + data_spec_version="v1", + file_loader=counting_loader, + ) + yield r def _setup(engine): @@ -99,3 +104,47 @@ def test_invalidate_revision_triggers_refetch(engine): # Second run: should refetch engine.run() assert call_count == 2, "Dataset with invalidated revision should be refetched" + + +def test_invalidate_revisions_batch(engine): + """invalidate_revisions works on multiple datasets at once.""" + global call_count + call_count = 0 + + dsv = DataSpecVersionCollection.from_dict({"default": {"v1"}}) + engine.add_ingestion_plan( + IngestionPlan( + source=SimpleSource("s", n_datasets=5), + fetch_policy=FetchPolicy(), + dataset_type="test", + selectors=[Selector.build({}, data_spec_versions=dsv)], + data_spec_versions=dsv, + ) + ) + + # First run: creates 5 datasets + engine.run() + assert call_count == 5 + + # Batch invalidate all 5 + datasets = list( + engine.store.get_dataset_collection( + provider="test_provider", dataset_type="test" + ) + ) + assert len(datasets) == 5 + engine.store.invalidate_revisions(datasets, reason="Batch test") + + # Verify all invalidated + datasets = list( + engine.store.get_dataset_collection( + provider="test_provider", dataset_type="test" + ) + ) + for ds in datasets: + assert ds.current_revision.state == RevisionState.VALIDATION_FAILED + assert ds.last_modified_at is None + + # Second run: should refetch all 5 + engine.run() + assert call_count == 10, "All 5 invalidated datasets should be refetched"