Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4152,6 +4152,124 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch):
assert_allclose_td(rb_test[:], rb[:])
assert rb.writer._cursor == rb_test._writer._cursor

@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
def test_incremental_checkpointing(self, storage_type, tmpdir):
"""Test incremental checkpointing saves only changed data."""
from torchrl.data.replay_buffers.checkpointers import TensorStorageCheckpointer

torch.manual_seed(0)
buffer_size = 100
batch_size = 20

# Create buffer with incremental checkpointing enabled
rb = ReplayBuffer(
storage=storage_type(buffer_size),
batch_size=batch_size,
)
rb.storage.checkpointer = TensorStorageCheckpointer(incremental=True)

# Create a second buffer to verify loads work correctly
rb_test = ReplayBuffer(
storage=storage_type(buffer_size),
batch_size=batch_size,
)
rb_test.storage.checkpointer = TensorStorageCheckpointer(incremental=True)

checkpoint_path = Path(tmpdir) / "checkpoint"

# Add first batch and checkpoint
data1 = TensorDict(
{
"obs": torch.randn(batch_size, 4),
"action": torch.randint(0, 2, (batch_size,)),
},
batch_size=[batch_size],
)
rb.extend(data1)
rb.dumps(checkpoint_path)

# Verify checkpoint cursor was set
assert rb.storage._last_checkpoint_cursor is not None
first_cursor = rb.storage._last_checkpoint_cursor

# Load and verify
rb_test.loads(checkpoint_path)
assert_allclose_td(rb_test[:], rb[:])
assert rb_test.storage._last_checkpoint_cursor == first_cursor

# Add second batch and checkpoint (should be incremental)
data2 = TensorDict(
{
"obs": torch.randn(batch_size, 4),
"action": torch.randint(0, 2, (batch_size,)),
},
batch_size=[batch_size],
)
rb.extend(data2)
rb.dumps(checkpoint_path)

# Verify cursor advanced
assert rb.storage._last_checkpoint_cursor > first_cursor

# Load and verify
rb_test.loads(checkpoint_path)
assert_allclose_td(rb_test[:], rb[:])

# Add more data until buffer wraps around
for _ in range(5):
data = TensorDict(
{
"obs": torch.randn(batch_size, 4),
"action": torch.randint(0, 2, (batch_size,)),
},
batch_size=[batch_size],
)
rb.extend(data)

# Checkpoint after wrap-around (should do full save)
rb.dumps(checkpoint_path)

# Load and verify
rb_test.loads(checkpoint_path)
assert_allclose_td(rb_test[:], rb[:])

@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
def test_incremental_checkpointing_no_changes(self, storage_type, tmpdir):
"""Test incremental checkpoint when no data changed."""
from torchrl.data.replay_buffers.checkpointers import TensorStorageCheckpointer

torch.manual_seed(0)
buffer_size = 50
batch_size = 10

rb = ReplayBuffer(
storage=storage_type(buffer_size),
batch_size=batch_size,
)
rb.storage.checkpointer = TensorStorageCheckpointer(incremental=True)

checkpoint_path = Path(tmpdir) / "checkpoint"

# Add data and checkpoint
data = TensorDict(
{"obs": torch.randn(batch_size, 4)},
batch_size=[batch_size],
)
rb.extend(data)
rb.dumps(checkpoint_path)

# Checkpoint again without adding data
rb.dumps(checkpoint_path)

# Load and verify
rb_test = ReplayBuffer(
storage=storage_type(buffer_size),
batch_size=batch_size,
)
rb_test.storage.checkpointer = TensorStorageCheckpointer(incremental=True)
rb_test.loads(checkpoint_path)
assert_allclose_td(rb_test[:], rb[:])


@pytest.mark.skipif(not _has_ray, reason="ray required for this test.")
class TestRayRB:
Expand Down
113 changes: 112 additions & 1 deletion torchrl/data/replay_buffers/checkpointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,82 @@ class TensorStorageCheckpointer(StorageCheckpointerBase):
This class will call save and load hooks if provided. These hooks should take as input the
data being transformed as well as the path where the data should be saved.

This checkpointer supports incremental saves when checkpointing to the same path repeatedly.
Only the data that changed since the last checkpoint is written, significantly reducing
checkpoint time for large buffers. This is controlled by the ``incremental`` parameter.

Keyword Args:
incremental (bool, optional): if ``True``, enables incremental checkpointing where only
modified data is saved when checkpointing to a path that already contains a checkpoint.
This can dramatically reduce checkpoint time for large buffers with frequent saves.
Defaults to ``False`` for backward compatibility.

"""

def __init__(self, *, incremental: bool = False):
    """Initialize the checkpointer.

    Keyword Args:
        incremental: when ``True``, ``dumps`` writes only the slice of the
            storage modified since the previous checkpoint at the same path;
            when ``False`` (default), every checkpoint is a full save.
    """
    super().__init__()
    self.incremental = incremental

def _compute_dirty_range(
self, last_checkpoint_cursor, current_cursor, max_size, is_full
):
"""Compute the range of indices that changed since the last checkpoint.

Args:
last_checkpoint_cursor: Cursor position at the time of the last checkpoint.
current_cursor: Current cursor position (where next write will go).
max_size: Maximum size of the storage along dimension 0.
is_full: Whether the storage is completely full.

Returns:
A tuple (start, end) representing the dirty range, or None if a full save is needed.
The range is [start, end) (end-exclusive).
"""
if last_checkpoint_cursor is None:
# First checkpoint, need full save
return None

if current_cursor == last_checkpoint_cursor:
# Cursor hasn't moved. But if the buffer is full, it might have wrapped around
# completely (wrote exactly max_size elements since last checkpoint).
# In that case, we need a full save.
# If buffer is not full and cursor hasn't moved, no changes.
if is_full:
# Could have wrapped around completely - do full save to be safe
return None
# No changes since last checkpoint
return (0, 0)

if current_cursor > last_checkpoint_cursor:
# Simple case: no wrap-around
return (last_checkpoint_cursor, current_cursor)

# Wrap-around occurred: current_cursor < last_checkpoint_cursor
# This means we wrote from last_checkpoint_cursor to max_size, then from 0 to current_cursor
# For simplicity, we do a full save on wrap-around since it's complex to handle
# and wrap-around typically means most of the buffer changed anyway
return None

def _save_incremental(self, storage, _storage, path, dirty_range):
    """Write the dirty slice of ``_storage`` into the memmap checkpoint at ``path``.

    Args:
        storage: the storage object (kept for interface symmetry; unused here).
        _storage: the in-memory tensor collection backing the storage.
        path: directory holding the existing memory-mapped checkpoint.
        dirty_range: half-open ``(start, end)`` slice of indices to persist.
    """
    start, end = dirty_range
    # An empty range means nothing changed since the last checkpoint.
    if start != end:
        # Map the existing on-disk checkpoint and overwrite only the stale
        # slice in place.
        # NOTE(review): relies on load_memmap returning disk-backed views so
        # that the in-place update_ persists — confirm against tensordict docs.
        on_disk = TensorDict.load_memmap(path)
        on_disk[start:end].update_(_storage[start:end])

def dumps(self, storage, path):
path = Path(path)
path.mkdir(exist_ok=True)
Expand All @@ -344,6 +418,13 @@ def dumps(self, storage, path):

self._set_hooks_shift_is_full(storage)

# Compute current cursor position for checkpoint tracking
last_cursor = storage._last_cursor
if last_cursor is not None:
current_checkpoint_cursor = self._get_shift_from_last_cursor(last_cursor)
else:
current_checkpoint_cursor = 0

for hook in self._save_hooks:
_storage = hook(_storage, path=path)
if is_tensor_collection(_storage):
Expand All @@ -353,8 +434,28 @@ def dumps(self, storage, path):
and Path(_storage.saved_path).absolute() == Path(path).absolute()
):
_storage.memmap_refresh_()
elif (
self.incremental
and (path / "storage_metadata.json").exists()
and storage._last_checkpoint_cursor is not None
):
# Incremental save: only save what changed
dirty_range = self._compute_dirty_range(
storage._last_checkpoint_cursor,
current_checkpoint_cursor,
storage._max_size_along_dim0(),
storage._is_full,
)
if dirty_range is not None:
self._save_incremental(storage, _storage, path, dirty_range)
else:
# Wrap-around or other case requiring full save
_storage.memmap(
path,
copy_existing=True,
)
else:
# try to load the path and overwrite.
# Full save (first checkpoint or incremental disabled)
_storage.memmap(
path,
copy_existing=True, # num_threads=torch.get_num_threads()
Expand All @@ -364,12 +465,16 @@ def dumps(self, storage, path):
_save_pytree(_storage, metadata, path)
is_pytree = True

# Update the checkpoint cursor for next incremental save
storage._last_checkpoint_cursor = current_checkpoint_cursor

with open(path / "storage_metadata.json", "w") as file:
json.dump(
{
"metadata": metadata,
"is_pytree": is_pytree,
"len": storage._len,
"last_checkpoint_cursor": current_checkpoint_cursor,
},
file,
)
Expand Down Expand Up @@ -453,6 +558,11 @@ def loads(self, storage, path):
storage._storage.copy_(_storage)
storage._len = _len

# Restore checkpoint cursor for incremental saves
last_checkpoint_cursor = metadata.get("last_checkpoint_cursor")
if last_checkpoint_cursor is not None:
storage._last_checkpoint_cursor = last_checkpoint_cursor


class FlatStorageCheckpointer(TensorStorageCheckpointer):
"""Saves the storage in a compact form, saving space on the TED format.
Expand Down Expand Up @@ -539,6 +649,7 @@ def __init__(
**kwargs,
):
StorageCheckpointerBase.__init__(self)
self.incremental = False # H5 does not support incremental saves
ted2_kwargs = kwargs
if done_keys is not None:
ted2_kwargs["done_keys"] = done_keys
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ def __init__(
)
self._storage = storage
self._last_cursor = None
self._last_checkpoint_cursor = None
self.__dict__["_storage_keys"] = None

@property
Expand Down
Loading