Skip to content

Commit 01acef7

Browse files
author
Luca Carminati
committed
Revert None filtering logic and fix bugs
1 parent c428fd4 commit 01acef7

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

torchrl/collectors/collectors.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3843,15 +3843,15 @@ def iterator(self) -> Iterator[TensorDictBase]:
38433843
# mask buffers if cat, and create a mask if stack
38443844
if cat_results != "stack":
38453845
buffers = [None] * self.num_workers
3846-
for idx, buffer in enumerate(filter(None.__ne__, self.buffers)):
3846+
for worker_idx, buffer in enumerate(filter(lambda x: x is not None, self.buffers)):
38473847
valid = buffer.get(("collector", "traj_ids")) != -1
38483848
if valid.ndim > 2:
38493849
valid = valid.flatten(0, -2)
38503850
if valid.ndim == 2:
38513851
valid = valid.any(0)
3852-
buffers[idx] = buffer[..., valid]
3852+
buffers[worker_idx] = buffer[..., valid]
38533853
else:
3854-
for buffer in filter(None.__ne__, self.buffers):
3854+
for buffer in filter(lambda x: x is not None, self.buffers):
38553855
with buffer.unlock_():
38563856
buffer.set(
38573857
("collector", "mask"),
@@ -3861,6 +3861,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
38613861
else:
38623862
buffers = self.buffers
38633863

3864+
# Skip frame counting if this worker didn't send data this iteration
3865+
# (happens when reusing buffers or on first iteration with some workers)
3866+
if self.buffers[idx] is None:
3867+
continue
3868+
38643869
workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()
38653870

38663871
if workers_frames[idx] >= self.total_frames:
@@ -3892,7 +3897,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
38923897
if same_device is None:
38933898
prev_device = None
38943899
same_device = True
3895-
for item in filter(None.__ne__, self.buffers):
3900+
for item in filter(lambda x: x is not None, self.buffers):
38963901
if prev_device is None:
38973902
prev_device = item.device
38983903
else:

0 commit comments

Comments
 (0)