diff --git a/mockfirestore/client.py b/mockfirestore/client.py index 75943bd..ebfc448 100644 --- a/mockfirestore/client.py +++ b/mockfirestore/client.py @@ -1,7 +1,7 @@ from typing import Iterable, Sequence from mockfirestore.collection import CollectionReference from mockfirestore.document import DocumentReference, DocumentSnapshot -from mockfirestore.transaction import Transaction +from mockfirestore.transaction import Transaction, Batch class MockFirestore: @@ -59,4 +59,7 @@ def get_all(self, references: Iterable[DocumentReference], def transaction(self, **kwargs) -> Transaction: return Transaction(self, **kwargs) + def batch(self) -> Batch: + return Batch(self) + diff --git a/mockfirestore/transaction.py b/mockfirestore/transaction.py index 7f06d2d..4d07c93 100644 --- a/mockfirestore/transaction.py +++ b/mockfirestore/transaction.py @@ -1,6 +1,6 @@ from functools import partial -import random from typing import Iterable, Callable + from mockfirestore._helpers import generate_random_string, Timestamp from mockfirestore.document import DocumentReference, DocumentSnapshot from mockfirestore.query import Query @@ -22,6 +22,7 @@ class Transaction: This mostly follows the model from https://googleapis.dev/python/firestore/latest/transaction.html """ + def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False): self._client = client @@ -117,3 +118,9 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: self.commit() + + +class Batch(Transaction): + def commit(self): + self._begin() # batch can call commit many times + super(Batch, self).commit()