diff --git a/src/grasprl/grasprl/envs/grasp.py b/src/grasprl/grasprl/envs/grasp.py index 0d3cac698d..08d812201b 100644 --- a/src/grasprl/grasprl/envs/grasp.py +++ b/src/grasprl/grasprl/envs/grasp.py @@ -1,249 +1,164 @@ -from os import path -from collections import defaultdict import numpy as np +from collections import defaultdict +from gymnasium import spaces from controllers.operational_space_controller import OSC from controllers.joint_effort_controller import GripperEffortCtrl -from gymnasium import spaces from renderer.mujoco_env import MujocoPhyEnv -import random +_target_box = ["ball_3","ball_2","ball_1","box_2","box_1","box_3"] _right_finger_name = "right_finger" _left_finger_name = "left_finger" -_close_finger_dis = 0.06 -_open_finger_dis = 0.152 _grasp_target_num = 6 -_target_box = ["ball_3","ball_2","ball_1","box_2","box_1","box_3"] -eyehand_target = [-0.02,-0.13,1.45,0,0,1,1] class GraspRobot(MujocoPhyEnv): - def __init__(self, model_path="../worlds/grasp.xml", frame_skip=200, **kwargs): - xml_file_path = path.join(path.dirname(path.realpath(__file__)), model_path) - self.fullpath = xml_file_path - super().__init__(xml_file_path, frame_skip,** kwargs) - - self.IMAGE_WIDTH = 64 - self.IMAGE_HEIGHT = 64 + def __init__(self, model_path="worlds/grasp.xml", frame_skip=40, render_mode=None): + self.fullpath = model_path + super().__init__(model_path, frame_skip=frame_skip) + self.render_mode = render_mode + self.IMAGE_WIDTH, self.IMAGE_HEIGHT = 64, 64 self._set_observation_space() - self.info = {} self._set_action_space() self.tolerance = 0.005 self.drop_area = [0.6, 0.0, 1.15] - self.arm_joints_names = list(self.model_names.joint_names[:6]) + self.TABLE_HEIGHT = 0.9 + self.arm_joints_names = list(self.model_names.joint_names[:6]) self.arm_joints = [self.mjcf_model.find('joint', name) for name in self.arm_joints_names] self.eef_name = self.model_names.site_names[1] self.eef_site = self.mjcf_model.find('site', self.eef_name) - self.TABLE_HEIGHT = 1.0 self.controller = OSC( physics=self.physics, joints=self.arm_joints, eef_site=self.eef_site, - min_effort=-150.0, - max_effort=150.0, + min_effort=-150, max_effort=150, kp=80, ko=80, kv=50, - vmax_xyz=1.0, vmax_abg=2.0 + vmax_xyz=1, vmax_abg=2 ) self.grp_ctrl = GripperEffortCtrl(physics=self.physics, gripper=self.gripper) self.target_objects = _target_box + self.grasped_num = 0 + self.grasp_step = 0 - def before_grasp(self, show=False): - self.reward = 0 - self.get_image_data("eyeinhand", depth=True, show=show) - qpos = np.nan_to_num(self.physics.data.qpos.copy(), nan=0.0, posinf=0.0, neginf=0.0) - self.physics.data.qpos[:] = qpos - for _ in range(self.frame_skip): - self.controller.run(eyehand_target) - self.grp_ctrl.run(signal=0) - self.physics.data.qpos[:] = np.nan_to_num(self.physics.data.qpos, nan=0.0, posinf=0.0, neginf=0.0) - self.physics.data.qacc[:] = np.nan_to_num(self.physics.data.qacc, nan=0.0, posinf=0.0, neginf=0.0) - self.physics.data.ctrl[:] = np.nan_to_num(self.physics.data.ctrl, nan=0.0, posinf=1.0, neginf=-1.0) - self.step_mujoco_simulation() - rgb, depth = self.get_image_data("eyeinhand", depth=True, show=show) - self.observation["rgb"] = rgb - self.observation["depth"] = depth - self.info['grasp'] = "Failed" - self.info["move"] = "Failed" - - def after_grasp(self, show=False): - self.get_image_data("eyeinhand", depth=True, show=show) - for _ in range(self.frame_skip): - self.controller.run(eyehand_target) - self.grp_ctrl.run(signal=0) - self.step_mujoco_simulation() - rgb, depth = self.get_image_data("eyeinhand", depth=True, show=show) - self.observation["rgb"] = rgb - self.observation["depth"] = depth - - def move_eef(self, action): - success = False - target_pose = action.copy() + [0,0,1,1] - for _ in range(self.frame_skip): - self.controller.run(target_pose) - self.physics.data.qpos[:] = np.nan_to_num(self.physics.data.qpos, nan=0.0, posinf=0.0, neginf=0.0) - self.physics.data.ctrl[:] = np.nan_to_num(self.physics.data.ctrl, nan=0.0, posinf=1.0, neginf=-1.0) - self.step_mujoco_simulation() - ee_pos = self.get_ee_pos() - if max(np.abs(ee_pos - action)) < self.tolerance: - success = True - if success: - self.info["move"] = f"move to target {action}" - return success + def _sanitize_physics_data(self): + for attr in ['qpos', 'qvel', 'ctrl', 'qacc']: + arr = getattr(self.physics.data, attr) + setattr(self.physics.data, attr, np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)) - def down_and_grasp(self, action): - down_success = False - target_pose = action.copy() - target_pose[2] -= 0.05 - target_pose += [0,0,1,1] - for _ in range(self.frame_skip): + def get_ee_pos(self): + return self.physics.bind(self.eef_site).xpos.copy() + + def get_body_com(self, body_name): + body_id = self.physics.model.name2id(body_name, 'body') + return self.physics.data.xpos[body_id].copy() + + def set_body_pos(self, body_name, pos): + body_id = self.physics.model.name2id(body_name, 'body') + self.physics.model.body_pos[body_id] = pos + + def world2pixel(self, cam_id, x, y, z): + fx = fy = 500 + cx = self.IMAGE_WIDTH / 2 + cy = self.IMAGE_HEIGHT / 2 + px = int((x * fx / z) + cx) + py = int((y * fy / z) + cy) + return px, py + + def pixel2world(self, cam_id, px, py, depth): + x = (px / self.IMAGE_WIDTH - 0.5) * 0.48 + y = (py / self.IMAGE_HEIGHT - 0.5) * 0.48 + z = depth + return np.array([x, y, z], dtype=np.float32) + + def _set_action_space(self): + self.action_space = spaces.Box(low=-0.25, high=0.25, shape=[3], dtype=np.float32) + + def _set_observation_space(self): + self.observation = defaultdict() + self.observation["rgb"] = np.zeros((self.IMAGE_WIDTH, self.IMAGE_HEIGHT, 3), dtype=np.float32) + self.observation["depth"] = np.zeros((self.IMAGE_WIDTH, self.IMAGE_HEIGHT), dtype=np.float32) + + def move_eef(self, target): + if hasattr(target, "tolist"): + target = target.tolist() + target_pose = target + [0, 0, 1, 1] + current_frame_skip = self.frame_skip if np.linalg.norm(np.array(self.get_ee_pos()) - np.array(target)) > 0.1 else 20 + for _ in range(current_frame_skip): self.controller.run(target_pose) + self._sanitize_physics_data() self.step_mujoco_simulation() - if max(np.abs(self.get_ee_pos() - action)) < self.tolerance: - down_success = True - if down_success: - for _ in range(self.frame_skip): - self.controller.run(target_pose) + if np.allclose(self.get_ee_pos(), target, atol=self.tolerance): + return True + return False + + def down_and_grasp(self, target): + down_pose = target.copy() + down_pose[2] -= 0.04 + success = self.move_eef(down_pose) + if success: + for _ in range(self.frame_skip // 2): self.grp_ctrl.run(signal=1) self.step_mujoco_simulation() - return down_success + return success def move_up_drop(self): - success = False up_pose = list(self.get_ee_pos()) - up_pose[2] += 0.1 - up_pose += [0,0,1,1] - drop_pose = self.drop_area + [0,0,1,1] - dist = np.linalg.norm(self.get_ee_pos() - self.get_body_com(self.target_objects[0])) - self.reward = -0.01 * dist - - for _ in range(self.frame_skip): - self.controller.run(up_pose) - self.step_mujoco_simulation() - - if self.check_grasp_success(): - self.info["grasp"] = "Success" + up_pose[2] += 0.12 + self.move_eef(up_pose) + grasp_success = self.check_grasp_success() + if grasp_success: self.grasped_num += 1 - self.reward = 1 - for _ in range(self.frame_skip): - self.controller.run(drop_pose) + for _ in range(self.frame_skip // 2): + self.grp_ctrl.run(signal=0) self.step_mujoco_simulation() - if max(np.abs(self.get_ee_pos() - self.drop_area)) < self.tolerance: - success = True - if success: - for _ in range(self.frame_skip): - self.controller.run(drop_pose) - self.grp_ctrl.run(signal=0) - self.step_mujoco_simulation() - return success - - def check_terminated(self): - for box in _target_box: - if self.get_body_com(box)[2] >= self.TABLE_HEIGHT: - return False - return True + return grasp_success def check_grasp_success(self): right = self.get_body_com(_right_finger_name) left = self.get_body_com(_left_finger_name) - dist = max(np.abs(right - left)) - return dist < 0.12 + dist = np.linalg.norm(right - left) + return dist < 0.16 def open_gripper(self): - target_pose = list(self.get_ee_pos()) + [0,0,1,1] - for _ in range(self.frame_skip): - self.controller.run(target_pose) + for _ in range(self.frame_skip // 2): self.grp_ctrl.run(signal=0) self.step_mujoco_simulation() - def move_and_grasp(self, action): - self.open_gripper() - action[2] = 1.18 - - if self.move_eef(action): - if self.down_and_grasp(action): - self.move_up_drop() - else: - self.open_gripper() - - def _set_action_space(self): - self.action_space1 = spaces.Box(low=-0.25, high=0.25, shape=[2]) - - def _set_observation_space(self): - self.observation = defaultdict() - self.observation["rgb"] = np.zeros((self.IMAGE_WIDTH, self.IMAGE_HEIGHT, 3)) - self.observation["depth"] = np.zeros((self.IMAGE_WIDTH, self.IMAGE_HEIGHT)) - def step(self, action): - self.terminated = False self.info = {} - self.before_grasp(show=False) - self.move_and_grasp(action) - self.after_grasp(show=False) - + self.open_gripper() + moved = self.move_eef(action) + grasped = self.down_and_grasp(action) if moved else False + success = self.move_up_drop() if grasped else False + ee_pos = self.get_ee_pos() obj_pos = self.get_body_com(self.target_objects[0]) dist = np.linalg.norm(ee_pos - obj_pos) - reward = -0.02 * dist - - if self.info.get("grasp") == "Success": - reward += 10.0 - elif self.info.get("move") == "Failed": - reward -= 1.0 - - self.reward = reward - - if self.grasped_num == _grasp_target_num or self.grasp_step == 5: - self.terminated = True - if self.check_terminated(): - self.terminated = True - - self.grasp_step += 1 - return self.observation, self.reward, self.terminated, self.info + reward = 0.5 - dist * 2.0 + + if success: + reward += 15.0 + self.info["grasp"] = "Success" + else: + self.info["grasp"] = "Failed" + + if not moved: + reward -= 0.2 - def reset_object(self): - for box_name in _target_box: - self.set_body_pos(box_name) + self.grasp_step += 1 + done = self.grasped_num == _grasp_target_num or self.grasp_step >= 18 + return self.observation, reward, done, self.info def reset(self): super().reset() - self.reset_object() self.grasped_num = 0 self.grasp_step = 0 - self.info["completion"] = "Failed" self.open_gripper() - self.before_grasp(show=False) return self.observation def reset_without_random(self): super().reset() self.grasped_num = 0 self.grasp_step = 0 - self.info["completion"] = "Failed" self.open_gripper() - self.before_grasp(show=False) - return self.observation - - def get_ee_pos(self): - return self.physics.bind(self.eef_site).xpos.copy() - - def pixel2world(self, cam_id, px, py, depth): - fx = fy = 500 - cx = self.IMAGE_WIDTH / 2 - cy = self.IMAGE_HEIGHT / 2 - x = (px - cx) * depth / fx - y = (py - cy) * depth / fy - return [x, y, depth] - - def world2pixel(self, cam_id, x, y, z): - fx = fy = 500 - cx = self.IMAGE_WIDTH / 2 - cy = self.IMAGE_HEIGHT / 2 - px = int((x * fx / z) + cx) - py = int((y * fy / z) + cy) - return px, py - - def set_body_pos(self, body_name, pos=None): - if pos is None: - pos = [random.uniform(-0.2, 0.2), random.uniform(-0.2, 0.2), self.TABLE_HEIGHT + 0.05] - body_id = self.physics.model.name2id(body_name, 'body') - self.physics.model.body_pos[body_id] = pos \ No newline at end of file + return self.observation \ No newline at end of file diff --git a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_dataset.py b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_dataset.py index 5b55a2e655..c5bd8fa4bd 100644 --- a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_dataset.py +++ b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_dataset.py @@ -1,116 +1,43 @@ import os import numpy as np -# 改成相对路径 data_path = os.path.dirname(os.path.abspath(__file__)) -need_fields = ["grasp_success"] -report_name = "dataset_check.txt" - -def check_files(): - total_rgb = 0 - total_label = 0 - no_label = [] - no_rgb = [] - label_err = [] - lack_field = [] - - rgb_list = [] - label_list = [] - for file in os.listdir(data_path): - full_path = os.path.join(data_path, file) - if file[:4] == "rgb_" and file[-4:] == ".png": - rgb_list.append(file) - total_rgb += 1 - elif file[:6] == "label_" and file[-4:] == ".npy": - label_list.append(file) - total_label += 1 - - rgb_idx = set() - for f in rgb_list: - try: - idx = int(f.replace("rgb_", "").replace(".png", "")) - rgb_idx.add(idx) - except: - label_err.append(f"{f}:文件名不对,得是rgb_数字.png这种格式") - - label_idx = set() - for f in label_list: - try: - idx = int(f.replace("label_", "").replace(".npy", "")) - label_idx.add(idx) - except: - label_err.append(f"{f}:标签文件名不对,得是label_数字.npy") - - for f in rgb_list: - try: - idx = int(f.replace("rgb_", "").replace(".png", "")) - if idx not in label_idx: - no_label.append(f) - except: - pass - - for f in label_list: - try: - idx = int(f.replace("label_", "").replace(".npy", "")) - if idx not in rgb_idx: - no_rgb.append(f) - except: - pass - - for label_file in label_list: - label_full = os.path.join(data_path, label_file) - try: - label_data = np.load(label_full, allow_pickle=True).item() - if type(label_data) != dict: - label_err.append(f"{label_file}:不是字典格式,存的时候要转字典!") - continue - lack = [] - for field in need_fields: - if field not in label_data: - lack.append(field) - if lack: - lack_field.append(f"{label_file}:少字段{lack}") - except: - label_err.append(f"{label_file}:文件坏了或者加载失败") - - report = [] - report.append("数据集检查结果") - report.append("------------") - report.append(f"RGB文件总数:{total_rgb}") - report.append(f"Label文件总数:{total_label}\n") - - report.append(f"有RGB但没Label的文件({len(no_label)}个):") - for f in no_label[:10]: - report.append(f" - {f}") - if len(no_label) > 10: - report.append(f" - 还有{len(no_label)-10}个没显示") - report.append("") - - report.append(f"有Label但没RGB的文件({len(no_rgb)}个):") - for f in no_rgb[:10]: - report.append(f" - {f}") - if len(no_rgb) > 10: - report.append(f" - 还有{len(no_rgb)-10}个没显示") - report.append("") - - report.append(f"Label格式错误({len(label_err)}个):") - for f in label_err[:10]: - report.append(f" - {f}") - if len(label_err) > 10: - report.append(f" - 还有{len(label_err)-10}个没显示") - report.append("") - - report.append(f"Label缺字段({len(lack_field)}个):") - for f in lack_field[:10]: - report.append(f" - {f}") - if len(lack_field) > 10: - report.append(f" - 还有{len(lack_field)-10}个没显示") - report.append("") - - with open(os.path.join(data_path, report_name), "w", encoding="utf-8") as f: - f.write('\n'.join(report)) - print('\n'.join(report)) - print(f"\n报告已经存到{data_path}里的{report_name}了") - -if __name__ == "__main__": - check_files() \ No newline at end of file +need_fields = ["grasp_success", "grasp_force", "grasp_pose"] + +total = 0 +lack = 0 +error_files = [] +label_err = [] + +files = [] +for f in os.listdir(data_path): + if f.endswith(".npy"): + files.append(f) + +for f in files: + total += 1 + p = os.path.join(data_path, f) + try: + d = np.load(p, allow_pickle=True) + if isinstance(d, np.ndarray): + d = d.item() + for field in need_fields: + if field not in d: + lack += 1 + label_err.append(f"{f} 缺少 {field}") + if "grasp_force" in d: + v = d["grasp_force"] + if not (0 <= v <= 100): + label_err.append(f"{f} grasp_force 异常") + if "grasp_pose" in d: + pose = d["grasp_pose"] + if not isinstance(pose, list) or len(pose) != 3: + label_err.append(f"{f} grasp_pose 异常") + except: + error_files.append(f) + +print("检查完成") +print("总文件", total) +print("缺失字段数", lack) +print("错误文件", len(error_files)) +print("数据异常", len(label_err)) \ No newline at end of file diff --git a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_success_value.py b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_success_value.py index 07acce64ed..1333637868 100644 --- a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_success_value.py +++ b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_success_value.py @@ -1,36 +1,36 @@ import os import numpy as np -#改成相对路径 data_path = os.path.dirname(os.path.abspath(__file__)) -success_num = 0 -total_num = 0 -error_files = [] -label_files = [f for f in os.listdir(data_path) if f.startswith("label_") and f.endswith(".npy")] +success = 0 +fail = 0 +forces = [] +sforces = [] +fforces = [] -print("===== 各Label的grasp_success值 =====") -for f in label_files: - total_num += 1 +files = [f for f in os.listdir(data_path) if f.endswith(".npy")] + +for f in files: + p = os.path.join(data_path, f) try: - label = np.load(os.path.join(data_path, f), allow_pickle=True).item() - val = label.get("grasp_success", -1) - print(f"{f} → {val}") - if val == 1: - success_num += 1 + d = np.load(p, allow_pickle=True) + if isinstance(d, np.ndarray): + d = d.item() + s = d.get("grasp_success", 0) + force = d.get("grasp_force", 50) + forces.append(force) + if s == 1: + success += 1 + sforces.append(force) + else: + fail += 1 + fforces.append(force) except: - error_files.append(f) - print(f"{f} → 读取失败") - -print("\n===== 统计结果 =====") -print(f"总数量:{total_num}") -print(f"成功数:{success_num}") -print(f"失败数:{total_num - success_num - len(error_files)}") -print(f"读取失败:{len(error_files)}") -if total_num > 0: - print(f"成功率:{success_num/total_num*100:.2f}%") + continue -if error_files: - print("\n===== 失败文件 =====") - for f in error_files: - print(f"- {f}") \ No newline at end of file +print("抓取成功", success) +print("抓取失败", fail) +print("平均力", np.mean(forces) if forces else 0) +print("成功平均力", np.mean(sforces) if sforces else 0) +print("失败平均力", np.mean(fforces) if fforces else 0) \ No newline at end of file diff --git a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py new file mode 100644 index 0000000000..c69ea8ba54 --- /dev/null +++ b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py @@ -0,0 +1,56 @@ +import os +import numpy as np + +data_path = os.path.dirname(os.path.abspath(__file__)) + +clean_rules = { + "grasp_success": {"valid_range": [0, 1], "default": 0}, + "grasp_force": {"valid_range": [0, 100], "default": 50}, + "grasp_pose": {"valid_type": list, "valid_length": 3, "default": [0, 0, 0]} +} + +clean_count = 0 +total_count = 0 +error_files = [] + +files = [f for f in os.listdir(data_path) if f.endswith(".npy")] + +log = [] + +for f in files: + total_count += 1 + p = os.path.join(data_path, f) + try: + data = np.load(p, allow_pickle=True) + if isinstance(data, np.ndarray): + data = data.item() + modified = False + for field, rule in clean_rules.items(): + if field not in data: + data[field] = rule["default"] + modified = True + continue + val = data[field] + if "valid_range" in rule: + minv, maxv = rule["valid_range"] + if not (minv <= val <= maxv): + data[field] = rule["default"] + modified = True + if "valid_type" in rule: + if not isinstance(val, rule["valid_type"]): + data[field] = rule["default"] + modified = True + if "valid_length" in rule: + if len(val) != rule["valid_length"]: + data[field] = rule["default"] + modified = True + if modified: + np.save(p, data, allow_pickle=True) + clean_count += 1 + except: + error_files.append(f) + +print("清洗完成") +print("总样本", total_count) +print("已修复", clean_count) +print("错误文件", len(error_files)) \ No newline at end of file diff --git a/src/grasprl/trainer/dqn_baseline.py b/src/grasprl/trainer/dqn_baseline.py new file mode 100644 index 0000000000..2491726282 --- /dev/null +++ b/src/grasprl/trainer/dqn_baseline.py @@ -0,0 +1,161 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import numpy as np +import random +import math +import cv2 +import os +from tqdm import tqdm +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from envs.grasp import GraspRobot +from modules.ddpg import ReplayBuffer, Transition + +IMAGE_SIZE = 32 +ACTION_SIZE = IMAGE_SIZE * IMAGE_SIZE +BATCH_SIZE = 64 +MEM_SIZE = 10000 +EPS_START = 1.0 +EPS_END = 0.05 +EPS_DECAY = 800 +GAMMA = 0.99 +LR = 0.001 +TARGET_UPDATE = 100 + +class VisualFeatureEnhancer(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(4, 16, 3, padding=1) + self.conv2 = nn.Conv2d(16, 4, 3, padding=1) + self.relu = nn.ReLU() + def forward(self, x): + x = self.relu(self.conv1(x)) + x = self.conv2(x) + return x + +class DQN(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(4, 32, 3, padding=1) + self.conv2 = nn.Conv2d(32, 64, 3, padding=1) + self.conv3 = nn.Conv2d(64, 128, 3, padding=1) + self.fc = nn.Linear(128 * IMAGE_SIZE * IMAGE_SIZE, ACTION_SIZE) + self.relu = nn.ReLU() + def forward(self, x): + x = self.relu(self.conv1(x)) + x = self.relu(self.conv2(x)) + x = self.relu(self.conv3(x)) + x = x.view(x.size(0), -1) + return self.fc(x) + +class DQN_Trainer: + def __init__(self, render_mode="human"): + self.env = GraspRobot(render_mode=render_mode) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.memory = ReplayBuffer(MEM_SIZE, simple=False) + self.policy_net = DQN().to(self.device) + self.target_net = DQN().to(self.device) + self.target_net.load_state_dict(self.policy_net.state_dict()) + self.target_net.eval() + self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LR) + self.criterion = nn.MSELoss() + self.steps_done = 0 + self.save_dir = os.path.join("grasprl", "dataset", "grasp_samples") + os.makedirs(self.save_dir, exist_ok=True) + self.enhancer = VisualFeatureEnhancer().to(self.device).eval() + + def transform_state(self, obs): + depth = obs["depth"].astype(np.float32) + depth = depth.max() - depth + if depth.max() > 0: + depth = depth / depth.max() + rgb = obs["rgb"].astype(np.float32) / 255.0 + rgb_t = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0).to(self.device) + depth_t = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0).to(self.device) + state = torch.cat([rgb_t, depth_t], dim=1).float() + with torch.no_grad(): + state = self.enhancer(state) + return state + + def transform_action(self, action_idx, depth_img): + idx = action_idx.item() + px = idx % IMAGE_SIZE + py = idx // IMAGE_SIZE + px = max(0, min(px, IMAGE_SIZE-1)) + py = max(0, min(py, IMAGE_SIZE-1)) + depth_val = depth_img[py][px] if depth_img[py][px] > 0 else np.mean(depth_img) + action = self.env.pixel2world(1, px, py, depth_val) + action = np.clip(action, [-0.25, -0.25, 1.05], [0.25, 0.25, 1.3]) + return action.tolist() + + def select_action(self, state): + eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * self.steps_done / EPS_DECAY) + self.steps_done += 1 + if random.random() > eps: + with torch.no_grad(): + q_vals = self.policy_net(state) + action = q_vals.argmax().item() + return torch.tensor([[action]], device=self.device) + else: + return torch.tensor([[random.randint(0, ACTION_SIZE-1)]], device=self.device) + + def learn(self): + if len(self.memory) < BATCH_SIZE: + return + transitions = self.memory.sample(BATCH_SIZE) + batch = Transition(*zip(*transitions)) + s = torch.cat(batch.state).to(self.device) + a = torch.cat(batch.action).to(self.device) + r = torch.cat(batch.reward).to(self.device) + ns = torch.cat(batch.next_state).to(self.device) + q_values = self.policy_net(s).gather(1, a) + with torch.no_grad(): + next_q = self.target_net(ns).max(1, keepdim=True)[0] + target = r + GAMMA * next_q + loss = self.criterion(q_values, target) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + if self.steps_done % TARGET_UPDATE == 0: + self.target_net.load_state_dict(self.policy_net.state_dict()) + + def save_sample(self, action, reward, info, i): + rgb = self.env.observation["rgb"] + cv2.imwrite(f"{self.save_dir}/rgb_{i}.png", cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)) + label = {"grasp_success": 1 if info["grasp"]=="Success" else 0, "reward": reward} + np.save(f"{self.save_dir}/label_{i}.npy", label) + +def main(): + max_episodes = 200 + trainer = DQN_Trainer(render_mode="human") + success_count = 0 + loop = tqdm(range(1, max_episodes+1)) + for ep in loop: + obs = trainer.env.reset_without_random() + state = trainer.transform_state(obs) + done = False + step = 0 + ep_reward = 0 + while not done and step < 10: + action_idx = trainer.select_action(state) + action = trainer.transform_action(action_idx, obs["depth"]) + next_obs, reward, done, info = trainer.env.step(action) + next_state = trainer.transform_state(next_obs) + trainer.memory.push(state, action_idx, torch.tensor([[reward]], device=trainer.device), next_state) + state = next_state + obs = next_obs + ep_reward += reward + step += 1 + trainer.learn() + if info.get("grasp") == "Success": + success_count += 1 + loop.set_postfix(success_rate=f"{success_count/ep:.2f}", ep_reward=f"{ep_reward:.2f}") + trainer.save_sample(action, reward, info, ep) + os.makedirs("grasprl/trained", exist_ok=True) + torch.save(trainer.policy_net.state_dict(), "grasprl/trained/dqn_final.pth") + trainer.env.close() + +if __name__ == "__main__": + main() \ No newline at end of file