diff --git a/.gitignore b/.gitignore index 41b950af36..a83a714b7f 100644 --- a/.gitignore +++ b/.gitignore @@ -184,3 +184,6 @@ pufferlib/resources/drive/output_agent.gif pufferlib/resources/drive/output.gif artifacts/ # Local drive renders pufferlib/resources/drive/output*.mp4 + +# Local TODO tracking +TODO.md diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 68b2e1ef2a..abcd2d9468 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -302,6 +302,12 @@ def evaluate(self): total_agents = len(o) num_agents_per_env = total_agents // batch_size + # Convert global ego_ids to local (per-environment) indices + # ego_ids contains ALL ego agents across ALL environments, not just this batch + # Count how many ego IDs belong to the first environment (IDs < num_agents_per_env) + num_ego_per_env = sum(1 for eid in ego_ids if eid < num_agents_per_env) + local_ego_ids = [eid % num_agents_per_env for eid in ego_ids[:num_ego_per_env]] + original_shape = o.shape o = o.reshape(batch_size, num_agents_per_env, *original_shape[1:]) @@ -309,10 +315,10 @@ def evaluate(self): d = d.reshape(batch_size, num_agents_per_env) t = t.reshape(batch_size, num_agents_per_env) - o = o[:, ego_ids].reshape(batch_size * len(ego_ids), *original_shape[1:]) - r = r[:, ego_ids].flatten() - d = d[:, ego_ids].flatten() - t = t[:, ego_ids].flatten() + o = o[:, local_ego_ids].reshape(batch_size * num_ego_per_env, *original_shape[1:]) + r = r[:, local_ego_ids].flatten() + d = d[:, local_ego_ids].flatten() + t = t[:, local_ego_ids].flatten() else: o = o[ego_ids] r = r[ego_ids]