
Commit a5db90e

Merge pull request #11 from blankey1337/feature/nvidia-sim-integration

additional sim improvements, dialing in the workflow for remote training

2 parents 969d6ad + ef2dacb

4 files changed: 295 additions & 4 deletions

File tree

docs/remote_simulation_workflow.md

Lines changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
# Remote Simulation & Training Workflow

This guide outlines how to develop, train, and test AlohaMini using a remote cloud GPU (e.g., Lambda Labs, AWS) to run the heavy simulation, while controlling everything from your local laptop (e.g., a MacBook).

## Architecture

* **Cloud Server (The "Lab")**: Runs NVIDIA Isaac Sim. Handles physics, rendering, and training.
* **Local Machine (The "Mission Control")**: Runs the Dashboard and teleoperation scripts. Connects to the cloud via SSH.

## Prerequisites

1. **Cloud Instance**: A server with an NVIDIA RTX-class GPU (A10, A100, RTX 3090/4090).
   * Recommended: Lambda Labs or Brev.dev (Ubuntu 20.04/22.04).
   * Must have **NVIDIA drivers** and **Isaac Sim** installed (or use the Isaac Sim Docker container).
2. **Local Machine**: Your laptop (Mac/Windows/Linux).
3. **SSH Access**: You must be able to SSH into the cloud instance.

## Setup

### 1. Cloud Server Setup

1. SSH into your cloud instance.
2. Clone this repository:
   ```bash
   git clone https://github.com/blankey1337/AlohaMini.git
   cd AlohaMini
   ```
3. Ensure you are in the Python environment that has access to Isaac Sim (often `./python.sh` in the Isaac Sim folder).

### 2. Local Machine Setup

1. Clone this repository locally.
2. Install dependencies:
   ```bash
   pip install -r software/requirements.txt
   ```
## The Workflow

### Phase 1: Data Collection

1. **Start the Simulation (Cloud)**
   Run the simulation environment script. It listens on ports 5555 (commands) and 5556 (observations).
   ```bash
   # On Cloud
   isaac_sim_python software/examples/alohamini/isaac_sim/isaac_alohamini_env.py
   ```

2. **Establish Connection (Local)**
   Forward the ZMQ ports from the cloud to your localhost:
   ```bash
   # On Local Mac
   ssh -L 5555:localhost:5555 -L 5556:localhost:5556 ubuntu@<CLOUD_IP>
   ```
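Before moving on, it can help to confirm the tunnel is actually up. A minimal stdlib sketch (the helper name and structure are mine; only the port numbers come from this guide):

```python
import socket

def port_open(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

if __name__ == "__main__":
    # With the SSH tunnel active, both forwarded ports should accept connections.
    for port in (5555, 5556):
        state = "open" if port_open("127.0.0.1", port) else "closed"
        print(f"port {port}: {state}")
```

Note this only verifies the forwarded TCP ports accept connections; actual data flow still depends on the simulator running on the cloud side.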
3. **Launch Dashboard (Local)**
   Start the web dashboard to see what the robot sees:
   ```bash
   # On Local Mac
   python software/dashboard/app.py
   ```
   Open `http://localhost:5001` in your browser.

4. **Teleoperate & Record**
   * Use the Dashboard to see the camera feed.
   * Run the teleop script in another terminal to control the robot with your keyboard:
     ```bash
     python software/examples/alohamini/standalone_teleop.py --ip 127.0.0.1
     ```
   * **To Record**: Click the **"Start Recording"** button on the Dashboard.
   * Perform the task (e.g., pick up the object).
   * Click **"Stop Recording"**.
   * Repeat 50-100 times. The data is saved to `AlohaMini/data_sim/` on the **Cloud Server**.
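Before training, it is worth checking how many episodes you actually captured. A small sketch, assuming the on-disk layout written by the `DatasetRecorder` in this PR (`data_sim/episode_<timestamp>/frame_NNNNNN.json`); run it on the cloud server:

```python
import glob
import os

def summarize_episodes(root: str = "data_sim") -> dict:
    """Map each recorded episode directory to its frame count."""
    summary = {}
    for ep in sorted(glob.glob(os.path.join(root, "episode_*"))):
        summary[os.path.basename(ep)] = len(glob.glob(os.path.join(ep, "frame_*.json")))
    return summary

if __name__ == "__main__":
    for name, frames in summarize_episodes().items():
        print(f"{name}: {frames} frames")
```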

### Phase 2: Training

You train the model directly on the cloud GPU, where the data lives.

1. **Stop the Simulation** (to free up GPU VRAM).
2. **Run Training**:
   Use the LeRobot training script (or your custom training script), pointing it at the generated dataset:
   ```bash
   # On Cloud
   python software/src/lerobot/scripts/train.py \
     --dataset data_sim \
     --policy act \
     --batch_size 8 \
     --num_epochs 1000
   ```
   *Note: The exact training command depends on the LeRobot configuration.*
3. **Output**: Training produces a model file (e.g., `outputs/train/policy.safetensors`).

### Phase 3: Evaluation

Test the trained model in the simulator to see if it works.

1. **Restart Simulation (Cloud)**:
   ```bash
   isaac_sim_python software/examples/alohamini/isaac_sim/isaac_alohamini_env.py
   ```
2. **Run Inference Node (Cloud or Local)**:
   You need a script that loads the model and closes the loop (read obs -> run model -> send action).
   * *Coming Soon: `eval_sim.py`, which loads the safetensors file and drives the ZMQ robot.*
3. **Watch (Local)**:
   Use the Dashboard to watch the robot perform the task autonomously.
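Until `eval_sim.py` lands, the control loop it needs can be sketched as follows. This is a sketch, not the repo's implementation: `policy` stands in for any callable mapping an observation dict to an action dict, and the message format assumes the JSON-over-ZMQ convention used elsewhere in this workflow:

```python
import json

import zmq

def run_policy_loop(policy, ip="127.0.0.1", obs_port=5556, cmd_port=5555, max_steps=None):
    """Close the loop: read obs -> run model -> send action."""
    ctx = zmq.Context.instance()
    obs_sock = ctx.socket(zmq.SUB)
    obs_sock.setsockopt(zmq.SUBSCRIBE, b"")
    obs_sock.setsockopt(zmq.CONFLATE, 1)  # only the latest observation matters
    obs_sock.connect(f"tcp://{ip}:{obs_port}")

    cmd_sock = ctx.socket(zmq.PUSH)
    cmd_sock.setsockopt(zmq.CONFLATE, 1)
    cmd_sock.connect(f"tcp://{ip}:{cmd_port}")

    steps = 0
    while max_steps is None or steps < max_steps:
        obs = json.loads(obs_sock.recv_string())
        action = policy(obs)  # e.g. {"x.vel": 0.1, "theta.vel": 0.0, ...}
        cmd_sock.send_string(json.dumps(action))
        steps += 1
```

The CONFLATE options mirror the dashboard's sockets: the loop always acts on the freshest observation rather than queuing stale ones.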
## Troubleshooting

* **Laggy Video**: ZMQ over an SSH tunnel is usually fast enough for 640x480, but if it lags, check your connection speed to the cloud server.
* **"Address already in use"**: Ensure no other Python scripts are using ports 5555/5556 on either machine.
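For the "Address already in use" case, the quickest fix is to find the process holding the port. A Linux/macOS sketch (tool availability varies by machine):

```shell
# List processes bound to the ZMQ ports (macOS / Linux with lsof)
lsof -i :5555 -i :5556 || true   # exits non-zero when nothing is listening
# Linux alternative using ss
ss -ltnp 2>/dev/null | grep -E ':(5555|5556)' || true
```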

software/dashboard/app.py

Lines changed: 39 additions & 3 deletions

```diff
@@ -5,23 +5,32 @@
 import zmq
 import cv2
 import numpy as np
-from flask import Flask, render_template, Response, jsonify
+from flask import Flask, render_template, Response, jsonify, request

 app = Flask(__name__)

 # Global state
 latest_observation = {}
 lock = threading.Lock()
 connected = False
+recording = False
+cmd_socket = None

-def zmq_worker(ip='127.0.0.1', port=5556):
-    global latest_observation, connected
+def zmq_worker(ip='127.0.0.1', port=5556, cmd_port=5555):
+    global latest_observation, connected, cmd_socket
     context = zmq.Context()
+
+    # Sub Socket
     socket = context.socket(zmq.SUB)
     socket.setsockopt(zmq.SUBSCRIBE, b"")
     socket.connect(f"tcp://{ip}:{port}")
     socket.setsockopt(zmq.CONFLATE, 1)

+    # Cmd Socket (Push to Sim)
+    cmd_socket = context.socket(zmq.PUSH)
+    cmd_socket.setsockopt(zmq.CONFLATE, 1)
+    cmd_socket.connect(f"tcp://{ip}:{cmd_port}")
+
     print(f"Connecting to ZMQ Stream at {ip}:{port}...")

     while True:
@@ -93,6 +102,33 @@ def video_feed(camera_name):
     return Response(generate_frames(camera_name),
                     mimetype='multipart/x-mixed-replace; boundary=frame')

+@app.route('/api/command', methods=['POST'])
+def send_command():
+    global cmd_socket
+    if not request.json or 'command' not in request.json:
+        return jsonify({'error': 'No command provided'}), 400
+
+    cmd = request.json['command']
+    print(f"Received command: {cmd}")
+
+    # Example handling
+    if cmd == 'reset_sim':
+        # Send reset command (Isaac Sim needs to handle this logic)
+        # For now, we can just zero out velocities or send a special flag
+        if cmd_socket:
+            cmd_socket.send_string(json.dumps({"reset": True}))
+
+    elif cmd == 'start_recording':
+        # Trigger recording logic
+        if cmd_socket:
+            cmd_socket.send_string(json.dumps({"start_recording": True}))
+
+    elif cmd == 'stop_recording':
+        if cmd_socket:
+            cmd_socket.send_string(json.dumps({"stop_recording": True}))
+
+    return jsonify({'status': 'ok'})
+
 @app.route('/api/status')
 def get_status():
     with lock:
```
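The new `/api/command` endpoint can be exercised without the UI, e.g. with `curl` (assuming the dashboard is reachable on port 5001, as the guide's dashboard URL suggests):

```shell
# Ask the dashboard to relay a reset command to the simulator;
# the server replies with {"status": "ok"} on success.
curl -X POST http://localhost:5001/api/command \
     -H 'Content-Type: application/json' \
     -d '{"command": "reset_sim"}' || echo "dashboard not running"
```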

software/dashboard/templates/index.html

Lines changed: 63 additions & 1 deletion

```diff
@@ -11,14 +11,27 @@
         .camera-box img { max-width: 100%; height: auto; display: block; }
         .camera-title { text-align: center; font-size: 0.9em; margin-bottom: 5px; }
         .status-panel { flex: 1; background: #333; padding: 15px; border-radius: 5px; min-width: 300px; }
+        .controls-panel { width: 100%; background: #333; padding: 15px; border-radius: 5px; margin-top: 20px; }
         table { width: 100%; border-collapse: collapse; }
         td, th { padding: 5px; border-bottom: 1px solid #444; font-size: 0.9em; }
         th { text-align: left; color: #aaa; }
         .value { font-family: monospace; color: #4f4; }
+        .btn { padding: 10px 20px; font-size: 16px; margin-right: 10px; cursor: pointer; background: #555; color: white; border: none; border-radius: 4px; }
+        .btn:hover { background: #777; }
+        .btn-record { background: #c00; }
+        .btn-record:hover { background: #e00; }
+        .status-indicator { display: inline-block; width: 10px; height: 10px; border-radius: 50%; margin-right: 5px; }
+        .status-active { background: #0f0; }
+        .status-inactive { background: #555; }
     </style>
 </head>
 <body>
-    <h1>AlohaMini Dashboard</h1>
+    <div style="display:flex; justify-content:space-between; align-items:center;">
+        <h1>AlohaMini Dashboard</h1>
+        <div>
+            <span id="connectionStatus" class="status-indicator status-inactive"></span> <span id="connectionText">Disconnected</span>
+        </div>
+    </div>

     <div class="container">
         <div class="camera-grid" id="cameraGrid">
@@ -33,14 +46,63 @@ <h3>Robot State</h3>
         </div>
     </div>

+    <div class="controls-panel">
+        <h3>Simulation Controls</h3>
+        <p>Control the remote simulation directly from here.</p>
+        <button class="btn" onclick="sendCommand('reset_sim')">Reset Simulation</button>
+        <button class="btn btn-record" id="recordBtn" onclick="toggleRecording()">Start Recording</button>
+        <span id="recordingStatus" style="margin-left: 10px; color: #aaa;"></span>
+    </div>
+
     <script>
         const knownCameras = ['head_top', 'head_back', 'head_front', 'wrist_left', 'wrist_right', 'cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'];
         const activeCameras = new Set();

+        let isRecording = false;
+
+        function toggleRecording() {
+            isRecording = !isRecording;
+            const btn = document.getElementById('recordBtn');
+            const status = document.getElementById('recordingStatus');
+
+            if (isRecording) {
+                btn.textContent = "Stop Recording";
+                btn.style.background = "#555";
+                status.textContent = "Recording in progress...";
+                sendCommand('start_recording');
+            } else {
+                btn.textContent = "Start Recording";
+                btn.style.background = "#c00";
+                status.textContent = "Saved.";
+                sendCommand('stop_recording');
+                setTimeout(() => status.textContent = "", 3000);
+            }
+        }
+
+        function sendCommand(cmd) {
+            console.log("Sending command:", cmd);
+            fetch('/api/command', {
+                method: 'POST',
+                headers: { 'Content-Type': 'application/json' },
+                body: JSON.stringify({ command: cmd })
+            });
+        }
+
         function updateStatus() {
             fetch('/api/status')
                 .then(response => response.json())
                 .then(data => {
+                    // Update connection indicator
+                    const connInd = document.getElementById('connectionStatus');
+                    const connText = document.getElementById('connectionText');
+                    if (data.connected) {
+                        connInd.className = 'status-indicator status-active';
+                        connText.textContent = "Connected";
+                    } else {
+                        connInd.className = 'status-indicator status-inactive';
+                        connText.textContent = "Disconnected";
+                    }
+
                     const table = document.getElementById('statusTable');
                     let rows = '';
```

software/examples/alohamini/isaac_sim/isaac_alohamini_env.py

Lines changed: 83 additions & 0 deletions

```diff
@@ -1,7 +1,10 @@
 import base64
 import json
 import os
+import random
 import sys
+import time
+from datetime import datetime

 import cv2
 import numpy as np
@@ -40,6 +43,54 @@
 # Locate URDF
 URDF_PATH = os.path.join(repo_root, "software/src/lerobot/robots/alohamini/alohamini.urdf")

+class DatasetRecorder:
+    def __init__(self, root_dir="data"):
+        self.root_dir = root_dir
+        self.is_recording = False
+        self.current_episode_dir = None
+        self.frame_idx = 0
+        self.episode_idx = 0
+
+        if not os.path.exists(root_dir):
+            os.makedirs(root_dir)
+
+    def start_recording(self):
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        self.current_episode_dir = os.path.join(self.root_dir, f"episode_{timestamp}")
+        os.makedirs(self.current_episode_dir)
+        os.makedirs(os.path.join(self.current_episode_dir, "images"))
+        self.frame_idx = 0
+        self.is_recording = True
+        print(f"Started recording to {self.current_episode_dir}")
+
+    def stop_recording(self):
+        if self.is_recording:
+            print(f"Stopped recording. Saved {self.frame_idx} frames.")
+        self.is_recording = False
+        self.current_episode_dir = None
+
+    def save_frame(self, obs, action):
+        if not self.is_recording:
+            return
+
+        # Save JSON data (state + action)
+        data = {
+            "timestamp": time.time(),
+            "observation": {k: v for k, v in obs.items() if not isinstance(v, np.ndarray)},
+            "action": action
+        }
+
+        with open(os.path.join(self.current_episode_dir, f"frame_{self.frame_idx:06d}.json"), "w") as f:
+            json.dump(data, f)
+
+        # Save Images
+        for k, v in obs.items():
+            if isinstance(v, np.ndarray):  # Image
+                img_path = os.path.join(self.current_episode_dir, "images", f"{k}_{self.frame_idx:06d}.jpg")
+                cv2.imwrite(img_path, v)
+
+        self.frame_idx += 1
+
 class IsaacAlohaMini:
     def __init__(self, world, urdf_path):
         self.world = world
@@ -188,6 +239,8 @@ def main():

     print(f"Isaac Sim AlohaMini running. Ports: OBS={PORT_OBS}, CMD={PORT_CMD}")

+    recorder = DatasetRecorder(root_dir="data_sim")
+
     while simulation_app.is_running():
         world.step(render=True)

@@ -203,7 +256,23 @@
             joint_cmds = {}
             vx, vy, vth = 0, 0, 0

+            # Check for system commands
+            if "start_recording" in cmd:
+                recorder.start_recording()
+                continue
+            if "stop_recording" in cmd:
+                recorder.stop_recording()
+                continue
+
             for k, v in cmd.items():
+                if k == "reset" and v is True:
+                    # Reset
+                    print("Resetting robot...")
+                    world.reset()
+                    joint_cmds = {name: 0.0 for name in aloha.dof_names}
+                    aloha.set_joint_positions(joint_cmds)
+                    continue
+
                 if k.endswith(".pos"):
                     joint_name = k.replace(".pos", "")
                     joint_cmds[joint_name] = v
@@ -226,6 +295,20 @@

         # 2. Get Obs & Publish
         obs = aloha.get_observations()
+
+        # Save frame if recording
+        # (Pass current command as action label for now, though it's imperfect as it's the *commanded* not *measured* action)
+        if recorder.is_recording:
+            # Reconstruct action dict from parsed values
+            # This is a simplification; ideally we log exactly what we sent
+            action_log = {
+                "x.vel": vx,
+                "y.vel": vy,
+                "theta.vel": vth,
+                # Add arm joint targets if we had them easily accessible here
+            }
+            recorder.save_frame(obs, action_log)
+
         encoded_obs = {}

         # Process images
```
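For downstream processing, an episode written by `DatasetRecorder` can be read back frame by frame. A sketch (the helper is mine; the paths mirror the `save_frame` layout above):

```python
import glob
import json
import os

def load_frame(episode_dir: str, idx: int) -> dict:
    """Load one recorded step: JSON state/action plus paths of saved images."""
    with open(os.path.join(episode_dir, f"frame_{idx:06d}.json")) as f:
        data = json.load(f)
    # Images are stored as images/<obs_key>_<frame>.jpg by save_frame()
    data["image_paths"] = sorted(
        glob.glob(os.path.join(episode_dir, "images", f"*_{idx:06d}.jpg"))
    )
    return data
```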
