diff --git a/test/test_collector.py b/test/test_collector.py
index 1be9bc9ed15..548965909ea 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -1712,6 +1712,186 @@ def env_fn():
             total_frames=frames_per_batch * 100,
         )
 
+    class FixedIDEnv(EnvBase):
+        """
+        A simple mock environment that returns a fixed ID as its sole observation.
+
+        This environment is designed to test MultiSyncDataCollector ordering.
+        Each environment instance is initialized with a unique env_id, which it
+        returns as the observation at every step.
+        """
+
+        def __init__(
+            self,
+            env_id: int,
+            max_steps: int = 10,
+            sleep_odd_only: bool = False,
+            **kwargs,
+        ):
+            """
+            Args:
+                env_id: The ID to return as observation. This will be returned as a tensor.
+                max_steps: Maximum number of steps before the environment terminates.
+            """
+            super().__init__(device="cpu", batch_size=torch.Size([]))
+            self.env_id = env_id
+            self.max_steps = max_steps
+            self.sleep_odd_only = sleep_odd_only
+            self._step_count = 0
+
+            # Define specs
+            self.observation_spec = Composite(
+                observation=Unbounded(shape=(1,), dtype=torch.float32)
+            )
+            self.action_spec = Composite(
+                action=Unbounded(shape=(1,), dtype=torch.float32)
+            )
+            self.reward_spec = Composite(
+                reward=Unbounded(shape=(1,), dtype=torch.float32)
+            )
+            self.done_spec = Composite(
+                done=Unbounded(shape=(1,), dtype=torch.bool),
+                terminated=Unbounded(shape=(1,), dtype=torch.bool),
+                truncated=Unbounded(shape=(1,), dtype=torch.bool),
+            )
+
+        def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
+            """Reset the environment and return the initial observation."""
+            # Add a sleep to simulate real-world timing variations.
+            # This helps test that the collector properly handles different reset times.
+            if not self.sleep_odd_only:
+                # Random sleep of up to 10ms
+                time.sleep(torch.rand(1).item() * 0.01)
+            elif self.env_id % 2 == 1:
+                time.sleep(1)
+
+            self._step_count = 0
+            return TensorDict(
+                {
+                    "observation": torch.tensor(
+                        [float(self.env_id)], dtype=torch.float32
+                    ),
+                    "done": torch.tensor([False], dtype=torch.bool),
+                    "terminated": torch.tensor([False], dtype=torch.bool),
+                    "truncated": torch.tensor([False], dtype=torch.bool),
+                },
+                batch_size=self.batch_size,
+            )
+
+        def _step(self, tensordict: TensorDict) -> TensorDict:
+            """Execute one step and return the env_id as observation."""
+            self._step_count += 1
+            done = self._step_count >= self.max_steps
+
+            if self.sleep_odd_only and self.env_id % 2 == 1:
+                time.sleep(1)
+
+            return TensorDict(
+                {
+                    "observation": torch.tensor(
+                        [float(self.env_id)], dtype=torch.float32
+                    ),
+                    "reward": torch.tensor([1.0], dtype=torch.float32),
+                    "done": torch.tensor([done], dtype=torch.bool),
+                    "terminated": torch.tensor([done], dtype=torch.bool),
+                    "truncated": torch.tensor([False], dtype=torch.bool),
+                },
+                batch_size=self.batch_size,
+            )
+
+        def _set_seed(self, seed: int | None) -> int | None:
+            """Set the seed for reproducibility."""
+            if seed is not None:
+                torch.manual_seed(seed)
+            return seed
+
+    @pytest.mark.parametrize("num_envs,n_steps", [(8, 5)])
+    @pytest.mark.parametrize("with_preempt", [False, True])
+    @pytest.mark.parametrize("cat_results", ["stack", -1])
+    def test_multi_sync_data_collector_ordering(
+        self, num_envs: int, n_steps: int, with_preempt: bool, cat_results: str | int
+    ):
+        """
+        Test that MultiSyncDataCollector returns data in the correct order.
+
+        We create num_envs environments, each returning its env_id as the observation.
+        After collection, we verify that the observations correspond to the correct env_ids, in order.
+        """
+        if with_preempt and IS_OSX:
+            pytest.skip(
+                "Cannot use preemption on OSX because Queue.qsize() is not implemented on this platform."
+            )
+
+        # Create environment factories using partial - one for each env_id.
+        # This pattern mirrors CrossPlayEvaluator._rollout usage.
+        env_factories = [
+            functools.partial(
+                self.FixedIDEnv, env_id=i, max_steps=10, sleep_odd_only=with_preempt
+            )
+            for i in range(num_envs)
+        ]
+
+        # Initialize MultiSyncDataCollector
+        collector = MultiSyncDataCollector(
+            create_env_fn=env_factories,
+            frames_per_batch=num_envs * n_steps,
+            total_frames=num_envs * n_steps,
+            device="cpu",
+            preemptive_threshold=0.5 if with_preempt else None,
+            cat_results=cat_results,
+            init_random_frames=n_steps,  # no need for a policy
+            use_buffers=True,
+        )
+
+        # Collect one batch
+        for batch in collector:
+            # Verify that each environment's observations match its env_id.
+            # batch has shape [num_envs, frames_per_env].
+            # In the preemption case, envs with odd ids are orders of magnitude slower.
+            # These should be skipped by preemption (since they are the 50% slowest).
+
+            # Recover the rectangular shape of the batch to make the checks uniform
+            if cat_results != "stack":
+                if not with_preempt:
+                    batch = batch.reshape(num_envs, n_steps)
+                else:
+                    traj_ids = batch["collector", "traj_ids"]
+                    traj_ids[traj_ids == 0] = 99  # avoid using traj_ids = 0
+                    # Split trajectories to recover the correct shape,
+                    # which works thanks to having a single trajectory per env.
+                    # Note: this pads with zeros!
+                    batch = split_trajectories(
+                        batch, trajectory_key=("collector", "traj_ids")
+                    )
+                    # Use -1 for padding to stay uniform with the other preemption checks
+                    is_padded = batch["collector", "traj_ids"] == 0
+                    batch[is_padded] = -1
+
+            # Check each environment's data
+            for env_idx in range(num_envs):
+                if with_preempt and env_idx % 2 == 1:
+                    # This is a slow env; it should have been preempted after the first step
+                    assert (batch["collector", "traj_ids"][env_idx, 1:] == -1).all()
+                    continue
+                # This is a fast env; no preemption happened
+                assert (batch["collector", "traj_ids"][env_idx] != -1).all()
+
+                env_data = batch[env_idx]
+                observations = env_data["observation"]
+                # All observations from this environment should equal its env_id
+                expected_id = float(env_idx)
+                actual_ids = observations.flatten().unique()
+
+                assert len(actual_ids) == 1, (
+                    f"Env {env_idx} should only produce observations with value {expected_id}, "
+                    f"but got {actual_ids.tolist()}"
+                )
+                assert (
+                    actual_ids[0].item() == expected_id
+                ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}"
+
+        collector.shutdown()
+
 
 class TestCollectorDevices:
     class DeviceLessEnv(EnvBase):
diff --git a/torchrl/collectors/_multi_sync.py b/torchrl/collectors/_multi_sync.py
index 1f756a8b26d..c423f1f02c6 100644
--- a/torchrl/collectors/_multi_sync.py
+++ b/torchrl/collectors/_multi_sync.py
@@ -216,7 +216,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
         if cat_results is None:
             cat_results = "stack"
 
-        self.buffers = {}
+        self.buffers = [None for _ in range(self.num_workers)]
         dones = [False for _ in range(self.num_workers)]
         workers_frames = [0 for _ in range(self.num_workers)]
         same_device = None
@@ -236,7 +236,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
                     msg = "continue_random"
                 else:
                     msg = "continue"
-                # Debug: sending 'continue'
                 self.pipes[idx].send((None, msg))
 
             self._iter += 1
@@ -299,8 +298,11 @@ def iterator(self) -> Iterator[TensorDictBase]:
             if preempt:
                 # mask buffers if cat, and create a mask if stack
                 if cat_results != "stack":
-                    buffers = {}
-                    for worker_idx, buffer in self.buffers.items():
+                    buffers = [None] * self.num_workers
+                    for worker_idx, buffer in enumerate(self.buffers):
+                        # Skip pre-empted envs:
+                        if buffer is None:
+                            continue
                         valid = buffer.get(("collector", "traj_ids")) != -1
                         if valid.ndim > 2:
                             valid = valid.flatten(0, -2)
@@ -308,7 +310,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
                             valid = valid.any(0)
                         buffers[worker_idx] = buffer[..., valid]
                 else:
-                    for buffer in self.buffers.values():
+                    for buffer in filter(lambda x: x is not None, self.buffers):
                         with buffer.unlock_():
                             buffer.set(
                                 ("collector", "mask"),
@@ -320,7 +322,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
 
                 # Skip frame counting if this worker didn't send data this iteration
                 # (happens when reusing buffers or on first iteration with some workers)
-                if idx not in buffers:
+                if self.buffers[idx] is None:
                     continue
                 workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()
@@ -331,18 +333,18 @@ def iterator(self) -> Iterator[TensorDictBase]:
             if self.replay_buffer is not None:
                 yield
                 self._frames += sum(
-                    [
-                        self.frames_per_batch_worker(worker_idx=worker_idx)
-                        for worker_idx in range(self.num_workers)
-                    ]
+                    self.frames_per_batch_worker(worker_idx)
+                    for worker_idx in range(self.num_workers)
                 )
                 continue
 
             # we have to correct the traj_ids to make sure that they don't overlap
             # We can count the number of frames collected for free in this loop
             n_collected = 0
-            for idx in buffers.keys():
+            for idx in range(self.num_workers):
                 buffer = buffers[idx]
+                if buffer is None:
+                    continue
                 traj_ids = buffer.get(("collector", "traj_ids"))
                 if preempt:
                     if cat_results == "stack":
@@ -356,7 +358,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
             if same_device is None:
                 prev_device = None
                 same_device = True
-                for item in self.buffers.values():
+                for item in filter(lambda x: x is not None, self.buffers):
                     if prev_device is None:
                         prev_device = item.device
                     else:
@@ -367,10 +369,12 @@ def iterator(self) -> Iterator[TensorDictBase]:
                     torch.stack if self._use_buffers else TensorDict.maybe_dense_stack
                 )
                 if same_device:
-                    self.out_buffer = stack(list(buffers.values()), 0)
+                    self.out_buffer = stack(
+                        [item for item in buffers if item is not None], 0
+                    )
                 else:
                     self.out_buffer = stack(
-                        [item.cpu() for item in buffers.values()], 0
+                        [item.cpu() for item in buffers if item is not None], 0
                     )
             else:
                 if self._use_buffers is None:
@@ -383,10 +387,13 @@ def iterator(self) -> Iterator[TensorDictBase]:
                     )
                 try:
                     if same_device:
-                        self.out_buffer = torch.cat(list(buffers.values()), cat_results)
+                        self.out_buffer = torch.cat(
+                            [item for item in buffers if item is not None], cat_results
+                        )
                     else:
                         self.out_buffer = torch.cat(
-                            [item.cpu() for item in buffers.values()], cat_results
+                            [item.cpu() for item in buffers if item is not None],
+                            cat_results,
                         )
                 except RuntimeError as err:
                     if (