diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index 29f4febd0b5..fffb220a010 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -804,6 +804,8 @@ def _update_action_mask(self, td, observation_dict, info_dict): group_mask = td.get((group, "action_mask")) group_mask += True for index, agent in enumerate(agents): + if agent not in observation_dict or agent not in info_dict: + continue agent_obs = observation_dict[agent] agent_info = info_dict[agent] if isinstance(agent_obs, dict) and "action_mask" in agent_obs: