@@ -3843,15 +3843,15 @@ def iterator(self) -> Iterator[TensorDictBase]:
38433843 # mask buffers if cat, and create a mask if stack
38443844 if cat_results != "stack" :
38453845 buffers = [None ] * self .num_workers
3846- for idx , buffer in enumerate (filter (None . __ne__ , self .buffers )):
3846+ for worker_idx , buffer in enumerate (filter (lambda x : x is not None , self .buffers )):
38473847 valid = buffer .get (("collector" , "traj_ids" )) != - 1
38483848 if valid .ndim > 2 :
38493849 valid = valid .flatten (0 , - 2 )
38503850 if valid .ndim == 2 :
38513851 valid = valid .any (0 )
3852- buffers [idx ] = buffer [..., valid ]
3852+ buffers [worker_idx ] = buffer [..., valid ]
38533853 else :
3854- for buffer in filter (None . __ne__ , self .buffers ):
3854+ for buffer in filter (lambda x : x is not None , self .buffers ):
38553855 with buffer .unlock_ ():
38563856 buffer .set (
38573857 ("collector" , "mask" ),
@@ -3861,6 +3861,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
38613861 else :
38623862 buffers = self .buffers
38633863
3864+ # Skip frame counting if this worker didn't send data this iteration
3865+ # (happens when reusing buffers or on first iteration with some workers)
3866+ if self .buffers [idx ] is None :
3867+ continue
3868+
38643869 workers_frames [idx ] = workers_frames [idx ] + buffers [idx ].numel ()
38653870
38663871 if workers_frames [idx ] >= self .total_frames :
@@ -3892,7 +3897,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
38923897 if same_device is None :
38933898 prev_device = None
38943899 same_device = True
3895- for item in filter (None . __ne__ , self .buffers ):
3900+ for item in filter (lambda x : x is not None , self .buffers ):
38963901 if prev_device is None :
38973902 prev_device = item .device
38983903 else :
0 commit comments