-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdemonstrate.py
More file actions
161 lines (121 loc) · 5.82 KB
/
demonstrate.py
File metadata and controls
161 lines (121 loc) · 5.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from rlbench.action_modes import ArmActionMode
from rlbench.backend.observation import Observation
from utils.utils import scale_panda_pose
from utils.utils import blank_image_list
from utils.utils import step_images
from config import EndToEndConfig
import numpy as np
import matplotlib.pyplot as plt
def get_image(obs: Observation, pov: str) -> np.ndarray:
"""
Gets depth image from an observation based on the point of view.
:param obs: RLBench observation at a given time step
:param pov: String of what point of view to return an image of
:return: 128x128x4 depth image as a numpy array.
"""
if pov == "wrist":
image = np.dstack((obs.wrist_rgb, obs.wrist_depth))
elif pov == "front":
image = np.dstack((obs.wrist_rgb, obs.wrist_depth))
else:
image = np.dstack((obs.wrist_rgb, obs.wrist_depth))
return image
def main():
    """
    Demonstrates a trained end-to-end imitation network in RLBench.

    Loads a user-selected trained network, runs it for a number of
    demonstration episodes on the task encoded in the network's name, and
    live-plots the per-axis (X/Y/Z) error between the network's auxiliary
    position estimates (target and gripper) and the simulator's ground truth.
    """
    print('[Info] Starting demonstrate.py')

    config = EndToEndConfig()

    # Select and load a previously trained network from disk.
    network_dir = config.get_trained_network_dir()
    network, network_info, _ = config.load_trained_network(network_dir)
    print(f'\n[Info] Finished loading the network, {network_info.network_name}.')

    # The task this network was trained on is encoded in its name.
    parsed_network_name = network_info.network_name.split('_')
    _, imitation_task = config.get_task_from_name(parsed_network_name)

    num_demonstrations = int(input('\nEnter how many demonstrations to perform (default 5): ') or 5)
    demonstration_episode_length = 100  # max steps per episode

    env = config.get_env(randomized=True)
    env.launch()
    task = env.get_task(imitation_task)

    # Interactive figure: one subplot per axis of estimation error.
    plt.ion()
    fig, axs = plt.subplots(nrows=3, ncols=1, sharex='all')
    fig.suptitle('Error of Gripper and Target Estimates')
    fig.text(0.5, 0.04, 'Time Step', ha='center')
    fig.text(0.04, 0.5, 'Difference [m]', va='center', rotation='vertical')

    lines = []
    for ax, d in zip(axs, ['X', 'Y', 'Z']):
        ax.set_title(f'{d}-Direction')
        line_t, = ax.plot(0, 0, 'b-', label='Target')
        line_g, = ax.plot(0, 0, 'r-', label='Gripper')
        # Static zero-error reference line.
        ax.plot(np.linspace(0, demonstration_episode_length, demonstration_episode_length),
                np.zeros(demonstration_episode_length), 'k-', label='Zero')
        lines.append((line_t, line_g))
    axs[0].legend(loc='lower left')

    evaluation_steps = num_demonstrations * demonstration_episode_length
    image_list = blank_image_list(network_info.num_images)
    obs = None

    for i in range(evaluation_steps):
        if i % demonstration_episode_length == 0:  # i.e. we're starting a new demonstration
            step = 0
            target_error = []
            gripper_error = []
            steps = []
            descriptions, obs = task.reset()
            image_list = blank_image_list(network_info.num_images)
            print(f"[Info] Task reset: on episode {int(1+(i/demonstration_episode_length))} "
                  f"of {int(evaluation_steps / demonstration_episode_length)}")
            input('Press enter to continue...')

        ##############################################################
        # Collect prediction information from the latest observation #
        ##############################################################
        image = get_image(obs, network_info.pov)
        image_list = step_images(image_list, image)
        image_input = np.expand_dims(np.dstack(image_list), 0)

        gripper_input = np.expand_dims(obs.gripper_open, 0)

        joints_input = scale_panda_pose(obs.joint_positions, 'down')  # to [0, 1] for prediction
        joints_input = np.expand_dims(joints_input, 0)

        #######################
        # Make the prediction #
        #######################
        prediction = network.predict(x=[joints_input, gripper_input, image_input])

        ##########################################################
        # Parse prediction for the actions and auxiliary outputs #
        ##########################################################
        joint_action = prediction[0].flatten()
        if config.rlbench_actionmode.arm == ArmActionMode.ABS_JOINT_POSITION:
            joint_action = scale_panda_pose(joint_action, 'up')  # from [0, 1] to joint's proper values

        gripper_action = np.argmax(prediction[1].flatten())

        target_estimation = prediction[2].flatten()
        gripper_estimation = prediction[3].flatten()

        try:
            # NOTE(review): reaches into RLBench internals and assumes the
            # task exposes a 'cup' object — confirm for other tasks.
            target_actual = task._task.cup.get_pose()
        except (AttributeError, NameError):
            # Bug fix: a missing 'cup' attribute raises AttributeError, which
            # the original 'except NameError' could never catch — the
            # fallback below was unreachable and the script crashed instead.
            print('[Error] Unable to find target. Returning infinity as position')
            target_actual = np.array([np.inf, np.inf, np.inf])

        gripper_actual = obs.gripper_pose

        #################
        # Create Graphs #
        #################
        target_error.append(target_estimation - target_actual[:3])
        gripper_error.append(gripper_estimation - gripper_actual[:3])
        step += 1
        steps.append(step)

        t_e = np.array(target_error)
        g_e = np.array(gripper_error)
        for s, ax in enumerate(axs):
            line_t, line_g = lines[s]
            line_t.set_data(steps, t_e[:, s])
            line_g.set_data(steps, g_e[:, s])
            ax.set_xlim(left=0, right=step)
            ax.set_ylim(bottom=min(g_e[:, s].min(), t_e[:, s].min()),
                        top=max(g_e[:, s].max(), t_e[:, s].max()))
        # Redraw once per time step (the original redrew once per axis,
        # tripling the canvas work for an identical final frame).
        fig.canvas.draw()
        fig.canvas.flush_events()

        #######################################################
        # Create action input and step the simulation forward #
        #######################################################
        action = np.append(joint_action, gripper_action)
        obs, reward, terminate = task.step(action)

    input('Press enter to exit...')
    env.shutdown()
    print('[Info] Successfully exiting program.')
# Script entry point: run the demonstration only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()