diff --git a/test/test_rb.py b/test/test_rb.py
index 86d22d4ba32..850168fdef6 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -4152,6 +4152,124 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch):
         assert_allclose_td(rb_test[:], rb[:])
         assert rb.writer._cursor == rb_test._writer._cursor
 
+    @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
+    def test_incremental_checkpointing(self, storage_type, tmpdir):
+        """Test that incremental checkpointing saves only changed data."""
+        from torchrl.data.replay_buffers.checkpointers import TensorStorageCheckpointer
+
+        torch.manual_seed(0)
+        buffer_size = 100
+        batch_size = 20
+
+        # Create buffer with incremental checkpointing enabled
+        rb = ReplayBuffer(
+            storage=storage_type(buffer_size),
+            batch_size=batch_size,
+        )
+        rb.storage.checkpointer = TensorStorageCheckpointer(incremental=True)
+
+        # Create a second buffer to verify loads work correctly
+        rb_test = ReplayBuffer(
+            storage=storage_type(buffer_size),
+            batch_size=batch_size,
+        )
+        rb_test.storage.checkpointer = TensorStorageCheckpointer(incremental=True)
+
+        checkpoint_path = Path(tmpdir) / "checkpoint"
+
+        # Add first batch and checkpoint
+        data1 = TensorDict(
+            {
+                "obs": torch.randn(batch_size, 4),
+                "action": torch.randint(0, 2, (batch_size,)),
+            },
+            batch_size=[batch_size],
+        )
+        rb.extend(data1)
+        rb.dumps(checkpoint_path)
+
+        # Verify checkpoint cursor was set
+        assert rb.storage._last_checkpoint_cursor is not None
+        first_cursor = rb.storage._last_checkpoint_cursor
+
+        # Load and verify
+        rb_test.loads(checkpoint_path)
+        assert_allclose_td(rb_test[:], rb[:])
+        assert rb_test.storage._last_checkpoint_cursor == first_cursor
+
+        # Add second batch and checkpoint (should be incremental)
+        data2 = TensorDict(
+            {
+                "obs": torch.randn(batch_size, 4),
+                "action": torch.randint(0, 2, (batch_size,)),
+            },
+            batch_size=[batch_size],
+        )
+        rb.extend(data2)
+        rb.dumps(checkpoint_path)
+
+        # Verify cursor advanced
+        assert rb.storage._last_checkpoint_cursor > first_cursor
+
+        # Load and verify
+        rb_test.loads(checkpoint_path)
+        assert_allclose_td(rb_test[:], rb[:])
+
+        # Add more data until buffer wraps around
+        for _ in range(5):
+            data = TensorDict(
+                {
+                    "obs": torch.randn(batch_size, 4),
+                    "action": torch.randint(0, 2, (batch_size,)),
+                },
+                batch_size=[batch_size],
+            )
+            rb.extend(data)
+
+        # Checkpoint after wrap-around (should do a full save)
+        rb.dumps(checkpoint_path)
+
+        # Load and verify
+        rb_test.loads(checkpoint_path)
+        assert_allclose_td(rb_test[:], rb[:])
+
+    @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
+    def test_incremental_checkpointing_no_changes(self, storage_type, tmpdir):
+        """Test incremental checkpointing when no data has changed."""
+        from torchrl.data.replay_buffers.checkpointers import TensorStorageCheckpointer
+
+        torch.manual_seed(0)
+        buffer_size = 50
+        batch_size = 10
+
+        rb = ReplayBuffer(
+            storage=storage_type(buffer_size),
+            batch_size=batch_size,
+        )
+        rb.storage.checkpointer = TensorStorageCheckpointer(incremental=True)
+
+        checkpoint_path = Path(tmpdir) / "checkpoint"
+
+        # Add data and checkpoint
+        data = TensorDict(
+            {"obs": torch.randn(batch_size, 4)},
+            batch_size=[batch_size],
+        )
+        rb.extend(data)
+        rb.dumps(checkpoint_path)
+
+        # Checkpoint again without adding data
+        rb.dumps(checkpoint_path)
+
+        # Load and verify
+        rb_test = ReplayBuffer(
+            storage=storage_type(buffer_size),
+            batch_size=batch_size,
+        )
+        rb_test.storage.checkpointer = TensorStorageCheckpointer(incremental=True)
+        rb_test.loads(checkpoint_path)
+        assert_allclose_td(rb_test[:], rb[:])
+
 
 @pytest.mark.skipif(not _has_ray, reason="ray required for this test.")
 class TestRayRB:
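Together the two tests cover the full lifecycle: first save, incremental save, no-op save, and wrap-around. A minimal usage sketch distilled from them follows; it assumes this branch's `incremental=True` keyword, and the sizes and checkpoint path are illustrative only.

```python
import tempfile

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, ReplayBuffer
from torchrl.data.replay_buffers.checkpointers import TensorStorageCheckpointer

rb = ReplayBuffer(storage=LazyMemmapStorage(10_000), batch_size=256)
rb.storage.checkpointer = TensorStorageCheckpointer(incremental=True)

ckpt = tempfile.mkdtemp()
for _ in range(4):
    rb.extend(TensorDict({"obs": torch.randn(256, 4)}, batch_size=[256]))
    # The first dumps() writes the whole storage; subsequent calls to the
    # same path only write the slice touched since the previous checkpoint.
    rb.dumps(ckpt)
```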
diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py
index 437976b7b1e..d7d141527a5 100644
--- a/torchrl/data/replay_buffers/checkpointers.py
+++ b/torchrl/data/replay_buffers/checkpointers.py
@@ -331,8 +331,82 @@ class TensorStorageCheckpointer(StorageCheckpointerBase):
 
     This class will call save and load hooks if provided. These hooks should take as input the
     data being transformed as well as the path where the data should be saved.
+
+    This checkpointer supports incremental saves when checkpointing to the same path repeatedly.
+    Only the data that changed since the last checkpoint is written, significantly reducing
+    checkpoint time for large buffers. This is controlled by the ``incremental`` parameter.
+
+    Keyword Args:
+        incremental (bool, optional): if ``True``, enables incremental checkpointing where only
+            modified data is saved when checkpointing to a path that already contains a checkpoint.
+            This can dramatically reduce checkpoint time for large buffers with frequent saves.
+            Defaults to ``False`` for backward compatibility.
     """
 
+    def __init__(self, *, incremental: bool = False):
+        super().__init__()
+        self.incremental = incremental
+
+    def _compute_dirty_range(
+        self, last_checkpoint_cursor, current_cursor, max_size, is_full
+    ):
+        """Compute the range of indices that changed since the last checkpoint.
+
+        Args:
+            last_checkpoint_cursor: Cursor position at the time of the last checkpoint.
+            current_cursor: Current cursor position (where the next write will go).
+            max_size: Maximum size of the storage along dimension 0.
+            is_full: Whether the storage is completely full.
+
+        Returns:
+            A tuple ``(start, end)`` giving the half-open dirty range ``[start, end)``,
+            or ``None`` if a full save is needed.
+        """
+        if last_checkpoint_cursor is None:
+            # First checkpoint: a full save is needed
+            return None
+
+        if current_cursor == last_checkpoint_cursor:
+            # The cursor hasn't moved. If the buffer is full, it may have wrapped
+            # around completely (exactly max_size elements written since the last
+            # checkpoint), so do a full save to be safe. Otherwise nothing changed.
+            if is_full:
+                return None
+            # No changes since last checkpoint
+            return (0, 0)
+
+        if current_cursor > last_checkpoint_cursor:
+            # Simple case: no wrap-around
+            return (last_checkpoint_cursor, current_cursor)
+
+        # Wrap-around occurred (current_cursor < last_checkpoint_cursor): writes went
+        # from last_checkpoint_cursor to max_size, then from 0 to current_cursor.
+        # Handling two disjoint slices adds complexity, and a wrap-around typically
+        # means most of the buffer changed anyway, so fall back to a full save.
+        return None
+
+    def _save_incremental(self, storage, _storage, path, dirty_range):
+        """Save only the dirty range to the existing memmap files.
+
+        Args:
+            storage: The storage object.
+            _storage: The underlying tensor collection.
+            path: Path to the checkpoint directory.
+            dirty_range: Tuple ``(start, end)`` of indices to save.
+        """
+        start, end = dirty_range
+        if start == end:
+            # No changes to save
+            return
+
+        # Load the existing memmap at path
+        existing = TensorDict.load_memmap(path)
+
+        # Update only the dirty indices
+        existing[start:end].update_(_storage[start:end])
+
     def dumps(self, storage, path):
         path = Path(path)
         path.mkdir(exist_ok=True)
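To make the branches of `_compute_dirty_range` concrete, here is a standalone restatement of the dirty-range rules in plain Python (not the torchrl API; `max_size` is omitted since the logic above only uses it to reason about the exact-wrap case, and the cursor values are illustrative):

```python
def compute_dirty_range(last, current, is_full):
    """Mirror of _compute_dirty_range above; ranges are end-exclusive."""
    if last is None:
        return None                          # first checkpoint: full save
    if current == last:
        return None if is_full else (0, 0)   # possible exact wrap vs. no change
    if current > last:
        return (last, current)               # contiguous dirty slice
    return None                              # wrapped past index 0: full save

assert compute_dirty_range(None, 20, False) is None    # first save
assert compute_dirty_range(20, 20, False) == (0, 0)    # nothing to write
assert compute_dirty_range(20, 40, False) == (20, 40)  # write rows [20, 40)
assert compute_dirty_range(80, 40, True) is None       # wrap-around
assert compute_dirty_range(20, 20, True) is None       # possible exact wrap
```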
+ """ + start, end = dirty_range + if start == end: + # No changes to save + return + + # Load existing memmap at path + existing = TensorDict.load_memmap(path) + + # Update only the dirty indices + existing[start:end].update_(_storage[start:end]) + def dumps(self, storage, path): path = Path(path) path.mkdir(exist_ok=True) @@ -344,6 +418,13 @@ def dumps(self, storage, path): self._set_hooks_shift_is_full(storage) + # Compute current cursor position for checkpoint tracking + last_cursor = storage._last_cursor + if last_cursor is not None: + current_checkpoint_cursor = self._get_shift_from_last_cursor(last_cursor) + else: + current_checkpoint_cursor = 0 + for hook in self._save_hooks: _storage = hook(_storage, path=path) if is_tensor_collection(_storage): @@ -353,8 +434,28 @@ def dumps(self, storage, path): and Path(_storage.saved_path).absolute() == Path(path).absolute() ): _storage.memmap_refresh_() + elif ( + self.incremental + and (path / "storage_metadata.json").exists() + and storage._last_checkpoint_cursor is not None + ): + # Incremental save: only save what changed + dirty_range = self._compute_dirty_range( + storage._last_checkpoint_cursor, + current_checkpoint_cursor, + storage._max_size_along_dim0(), + storage._is_full, + ) + if dirty_range is not None: + self._save_incremental(storage, _storage, path, dirty_range) + else: + # Wrap-around or other case requiring full save + _storage.memmap( + path, + copy_existing=True, + ) else: - # try to load the path and overwrite. + # Full save (first checkpoint or incremental disabled) _storage.memmap( path, copy_existing=True, # num_threads=torch.get_num_threads() @@ -364,12 +465,16 @@ def dumps(self, storage, path): _save_pytree(_storage, metadata, path) is_pytree = True + # Update the checkpoint cursor for next incremental save + storage._last_checkpoint_cursor = current_checkpoint_cursor + with open(path / "storage_metadata.json", "w") as file: json.dump( { "metadata": metadata, "is_pytree": is_pytree, "len": storage._len, + "last_checkpoint_cursor": current_checkpoint_cursor, }, file, ) @@ -453,6 +558,11 @@ def loads(self, storage, path): storage._storage.copy_(_storage) storage._len = _len + # Restore checkpoint cursor for incremental saves + last_checkpoint_cursor = metadata.get("last_checkpoint_cursor") + if last_checkpoint_cursor is not None: + storage._last_checkpoint_cursor = last_checkpoint_cursor + class FlatStorageCheckpointer(TensorStorageCheckpointer): """Saves the storage in a compact form, saving space on the TED format. @@ -539,6 +649,7 @@ def __init__( **kwargs, ): StorageCheckpointerBase.__init__(self) + self.incremental = False # H5 does not support incremental saves ted2_kwargs = kwargs if done_keys is not None: ted2_kwargs["done_keys"] = done_keys diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index e6cdd64d583..05c6bc085a2 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -733,6 +733,7 @@ def __init__( ) self._storage = storage self._last_cursor = None + self._last_checkpoint_cursor = None self.__dict__["_storage_keys"] = None @property