Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions examples/3_replay_neuracore_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from pathlib import Path
from typing import cast

import cv2
import neuracore as nc
import numpy as np
from common.configs import (
CAMERA_LOGGING_NAME,
GRIPPER_LOGGING_NAME,
JOINT_NAMES,
NEUTRAL_JOINT_ANGLES,
Expand Down Expand Up @@ -117,7 +117,7 @@ def main() -> None:
episode = synced_dataset[episode_idx]

print(f"\n🚀 Collecting episode {episode_idx} data...")
rgb_images = []
rgb_frames_per_step: list[dict[str, np.ndarray]] = []
parallel_gripper_open_amounts = []
joint_positions = []
for step in tqdm(episode, desc=f"Collecting episode {episode_idx}"):
Expand All @@ -144,11 +144,13 @@ def main() -> None:
gripper_value = gripper_data[GRIPPER_LOGGING_NAME].open_amount
parallel_gripper_open_amounts.append(gripper_value)

# Extract RGB image (just store first one for compatibility)
# Extract RGB for all cameras
step_frames: dict[str, np.ndarray] = {}
if DataType.RGB_IMAGES in step.data:
rgb_data = step.data[DataType.RGB_IMAGES]
if CAMERA_LOGGING_NAME in rgb_data:
rgb_images.append(rgb_data[CAMERA_LOGGING_NAME].frame)
for camera_name, img_value in rgb_data.items():
step_frames[camera_name] = img_value.frame
rgb_frames_per_step.append(step_frames)

joint_positions = np.degrees(np.array(joint_positions))
parallel_gripper_open_amounts = np.array(parallel_gripper_open_amounts)
Expand All @@ -162,8 +164,20 @@ def main() -> None:
robot_controller.set_gripper_open_value(
parallel_gripper_open_amounts[index]
)

# Display camera frames (dataset stores RGB; OpenCV expects BGR)
if index < len(rgb_frames_per_step):
for camera_name, frame_rgb in rgb_frames_per_step[index].items():
arr = np.asarray(frame_rgb, dtype=np.uint8)
frame_bgr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
cv2.imshow(f"Replay: {camera_name}", frame_bgr)
if cv2.waitKey(1) & 0xFF == ord("q"):
print("\n🛑 'q' pressed, stopping replay...")
break

end_time = time.time()
time.sleep(max(0, 1 / args.frequency - (end_time - start_time)))
cv2.destroyAllWindows()
print(f"🎉 Episode {episode_idx} replay completed.")

if args.episode_index == -1:
Expand All @@ -172,6 +186,7 @@ def main() -> None:
print(f"{'='*60}")
except KeyboardInterrupt:
print("\n🛑 Keyboard interrupt detected, stopping robot control loop...")
cv2.destroyAllWindows()

robot_controller.stop_control_loop()
robot_controller.cleanup()
Expand Down