diff --git a/examples/3_replay_neuracore_episodes.py b/examples/3_replay_neuracore_episodes.py index 86a6b5c..fc06f26 100644 --- a/examples/3_replay_neuracore_episodes.py +++ b/examples/3_replay_neuracore_episodes.py @@ -6,10 +6,10 @@ from pathlib import Path from typing import cast +import cv2 import neuracore as nc import numpy as np from common.configs import ( - CAMERA_LOGGING_NAME, GRIPPER_LOGGING_NAME, JOINT_NAMES, NEUTRAL_JOINT_ANGLES, @@ -117,7 +117,7 @@ def main() -> None: episode = synced_dataset[episode_idx] print(f"\nšŸš€ Collecting episode {episode_idx} data...") - rgb_images = [] + rgb_frames_per_step: list[dict[str, np.ndarray]] = [] parallel_gripper_open_amounts = [] joint_positions = [] for step in tqdm(episode, desc=f"Collecting episode {episode_idx}"): @@ -144,11 +144,13 @@ def main() -> None: gripper_value = gripper_data[GRIPPER_LOGGING_NAME].open_amount parallel_gripper_open_amounts.append(gripper_value) - # Extract RGB image (just store first one for compatibility) + # Extract RGB for all cameras + step_frames: dict[str, np.ndarray] = {} if DataType.RGB_IMAGES in step.data: rgb_data = step.data[DataType.RGB_IMAGES] - if CAMERA_LOGGING_NAME in rgb_data: - rgb_images.append(rgb_data[CAMERA_LOGGING_NAME].frame) + for camera_name, img_value in rgb_data.items(): + step_frames[camera_name] = img_value.frame + rgb_frames_per_step.append(step_frames) joint_positions = np.degrees(np.array(joint_positions)) parallel_gripper_open_amounts = np.array(parallel_gripper_open_amounts) @@ -162,8 +164,20 @@ def main() -> None: robot_controller.set_gripper_open_value( parallel_gripper_open_amounts[index] ) + + # Display camera frames (dataset stores RGB; OpenCV expects BGR) + if index < len(rgb_frames_per_step): + for camera_name, frame_rgb in rgb_frames_per_step[index].items(): + arr = np.asarray(frame_rgb, dtype=np.uint8) + frame_bgr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) + cv2.imshow(f"Replay: {camera_name}", frame_bgr) + if cv2.waitKey(1) & 0xFF == ord("q"): + print("\nšŸ›‘ 'q' pressed, stopping replay...") + break + end_time = time.time() time.sleep(max(0, 1 / args.frequency - (end_time - start_time))) + cv2.destroyAllWindows() print(f"šŸŽ‰ Episode {episode_idx} replay completed.") if args.episode_index == -1: @@ -172,6 +186,7 @@ def main() -> None: print(f"{'='*60}") except KeyboardInterrupt: print("\nšŸ›‘ Keyboard interrupt detected, stopping robot control loop...") + cv2.destroyAllWindows() robot_controller.stop_control_loop() robot_controller.cleanup()