diff --git a/pufferlib/bitworld_pufferlib.py b/pufferlib/bitworld_pufferlib.py index cd47e18a..e39b14ac 100644 --- a/pufferlib/bitworld_pufferlib.py +++ b/pufferlib/bitworld_pufferlib.py @@ -1150,12 +1150,15 @@ def reset(self) -> np.ndarray: self.episode_steps = 0 return frame - def step(self, action_mask: int) -> tuple[np.ndarray, float]: + def step(self, action_masks: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + masks = np.asarray(action_masks, dtype=np.uint8) + if masks.shape != (self.agent_count,): + raise ValueError(f"expected {self.agent_count} BitWorld action masks") assert self.connection is not None with self._condition: start_seq = self._frame_seq start_reward_seq = self._reward_seq - self.connection.send(bytes([action_mask]), text=False) + self.connection.send(bytes([int(masks[0])]), text=False) frame, _ = self._wait_for_frame( lambda _item, seq: seq >= start_seq + self.action_repeat ) @@ -1166,7 +1169,9 @@ def step(self, action_mask: int) -> tuple[np.ndarray, float]: self.score = snapshot self.episode_return += reward_delta self.episode_steps += 1 - return frame, reward_delta + frames = np.asarray(frame)[np.newaxis] + rewards = np.asarray([reward_delta], dtype=np.float32) + return frames, rewards def close(self) -> None: with self._condition: @@ -1436,13 +1441,8 @@ def _step_env(self, env_id: int, action_indices: np.ndarray): agent_slice = self._agent_slice(env_id) clipped = np.clip(action_indices[agent_slice], 0, self.action_count - 1).astype(np.int64) action_masks = ACTION_MASKS[clipped] - if worker.agent_count == 1: - frame, reward = worker.step(int(action_masks[0])) - frames = self._frame_batch(frame, worker) - rewards = np.asarray([reward], dtype=np.float32) - else: - frames, rewards = worker.step(action_masks) - frames = self._frame_batch(frames, worker) + frames, rewards = worker.step(action_masks) + frames = self._frame_batch(frames, worker) completed: list[EpisodeStats] = [] if isinstance(worker, AmongThemNativeWorker):