From 47bc6814b77adfbcf79be919c51331efec94e906 Mon Sep 17 00:00:00 2001 From: WK5605 <3031585504@qq.com> Date: Thu, 30 Apr 2026 17:02:36 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E6=9C=BA=E6=A2=B0=E6=8A=93=E5=8F=96?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=9B=86=E4=BC=98=E5=8C=96=EF=BC=9A=E6=A3=80?= =?UTF-8?q?=E6=9F=A5=E3=80=81=E7=BB=9F=E8=AE=A1=E3=80=81=E6=B8=85=E6=B4=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dataset/grasp_samples/clean_grasp_data.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py diff --git a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py new file mode 100644 index 0000000000..ec00a4cfd6 --- /dev/null +++ b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py @@ -0,0 +1,60 @@ +import os +import numpy as np + +data_path = os.path.dirname(os.path.abspath(__file__)) + +clean_rules = { + "grasp_success": {"valid_range": [0,1], "default": 0}, + "grasp_force": {"valid_range": [0, 100], "default": 50}, + "grasp_pose": {"valid_type": list, "default": [0,0,0]} +} + +clean_count = 0 +total_count = 0 +error_files = [] + +label_files = [f for f in os.listdir(data_path) if f.startswith("label_") and f.endswith(".npy")] + +print("===== 机械抓取数据清洗 =====") +for f in label_files: + total_count += 1 + label_path = os.path.join(data_path, f) + + try: + label = np.load(label_path, allow_pickle=True).item() + modified = False + + for field, rule in clean_rules.items(): + if field not in label: + label[field] = rule["default"] + modified = True + print(f"{f} → 缺失 {field},已填充") + continue + + val = label[field] + + if "valid_range" in rule: + vmin, vmax = rule["valid_range"] + if not (vmin <= val <= vmax): + label[field] = rule["default"] + modified = True + print(f"{f} → {field} 异常,已修正") + + if "valid_type" in rule: 
+ if not isinstance(val, rule["valid_type"]): + label[field] = rule["default"] + modified = True + print(f"{f} → {field} 类型错误,已修正") + + if modified: + np.save(label_path, label, allow_pickle=True) + clean_count += 1 + + except: + error_files.append(f) + print(f"{f} → 读取失败") + +print("\n清洗完成") +print(f"总样本:{total_count}") +print(f"已修正:{clean_count}") +print(f"异常文件:{len(error_files)}") \ No newline at end of file From bad9bef26f25e8d84ed0a65fea0155044d12043b Mon Sep 17 00:00:00 2001 From: WK5605 <3031585504@qq.com> Date: Fri, 1 May 2026 22:44:12 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=9C=BA=E6=A2=B0=E6=8A=93=E5=8F=96DQN?= =?UTF-8?q?=E6=A8=A1=E6=8B=9F=E5=99=A8=E4=BF=AE=E5=A4=8D=E4=B8=8E=E4=BC=98?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/grasprl/grasprl/envs/grasp.py | 265 +++++++++------------------- src/grasprl/trainer/dqn_baseline.py | 160 +++++++++++++++++ 2 files changed, 240 insertions(+), 185 deletions(-) create mode 100644 src/grasprl/trainer/dqn_baseline.py diff --git a/src/grasprl/grasprl/envs/grasp.py b/src/grasprl/grasprl/envs/grasp.py index 0d3cac698d..50ef0dd407 100644 --- a/src/grasprl/grasprl/envs/grasp.py +++ b/src/grasprl/grasprl/envs/grasp.py @@ -1,249 +1,144 @@ -from os import path -from collections import defaultdict import numpy as np +from collections import defaultdict +from gymnasium import spaces from controllers.operational_space_controller import OSC from controllers.joint_effort_controller import GripperEffortCtrl -from gymnasium import spaces from renderer.mujoco_env import MujocoPhyEnv -import random +_target_box = ["ball_3","ball_2","ball_1","box_2","box_1","box_3"] _right_finger_name = "right_finger" _left_finger_name = "left_finger" -_close_finger_dis = 0.06 -_open_finger_dis = 0.152 _grasp_target_num = 6 -_target_box = ["ball_3","ball_2","ball_1","box_2","box_1","box_3"] -eyehand_target = [-0.02,-0.13,1.45,0,0,1,1] class GraspRobot(MujocoPhyEnv): - def 
__init__(self, model_path="../worlds/grasp.xml", frame_skip=200, **kwargs): - xml_file_path = path.join(path.dirname(path.realpath(__file__)), model_path) - self.fullpath = xml_file_path - super().__init__(xml_file_path, frame_skip,** kwargs) - - self.IMAGE_WIDTH = 64 - self.IMAGE_HEIGHT = 64 + def __init__(self, model_path="worlds/grasp.xml", frame_skip=50, render_mode=None): + self.fullpath = model_path + super().__init__(model_path, frame_skip=frame_skip) + self.render_mode = render_mode + self.IMAGE_WIDTH, self.IMAGE_HEIGHT = 64,64 self._set_observation_space() - self.info = {} self._set_action_space() self.tolerance = 0.005 - self.drop_area = [0.6, 0.0, 1.15] - self.arm_joints_names = list(self.model_names.joint_names[:6]) + self.drop_area = [0.6,0.0,1.15] + self.TABLE_HEIGHT = 0.9 + self.arm_joints_names = list(self.model_names.joint_names[:6]) self.arm_joints = [self.mjcf_model.find('joint', name) for name in self.arm_joints_names] self.eef_name = self.model_names.site_names[1] self.eef_site = self.mjcf_model.find('site', self.eef_name) - self.TABLE_HEIGHT = 1.0 self.controller = OSC( physics=self.physics, joints=self.arm_joints, eef_site=self.eef_site, - min_effort=-150.0, - max_effort=150.0, + min_effort=-150, max_effort=150, kp=80, ko=80, kv=50, - vmax_xyz=1.0, vmax_abg=2.0 + vmax_xyz=1, vmax_abg=2 ) self.grp_ctrl = GripperEffortCtrl(physics=self.physics, gripper=self.gripper) self.target_objects = _target_box + self.grasped_num = 0 + self.grasp_step = 0 - def before_grasp(self, show=False): - self.reward = 0 - self.get_image_data("eyeinhand", depth=True, show=show) - qpos = np.nan_to_num(self.physics.data.qpos.copy(), nan=0.0, posinf=0.0, neginf=0.0) - self.physics.data.qpos[:] = qpos - for _ in range(self.frame_skip): - self.controller.run(eyehand_target) - self.grp_ctrl.run(signal=0) - self.physics.data.qpos[:] = np.nan_to_num(self.physics.data.qpos, nan=0.0, posinf=0.0, neginf=0.0) - self.physics.data.qacc[:] = np.nan_to_num(self.physics.data.qacc, 
nan=0.0, posinf=0.0, neginf=0.0) - self.physics.data.ctrl[:] = np.nan_to_num(self.physics.data.ctrl, nan=0.0, posinf=1.0, neginf=-1.0) - self.step_mujoco_simulation() - rgb, depth = self.get_image_data("eyeinhand", depth=True, show=show) - self.observation["rgb"] = rgb - self.observation["depth"] = depth - self.info['grasp'] = "Failed" - self.info["move"] = "Failed" - - def after_grasp(self, show=False): - self.get_image_data("eyeinhand", depth=True, show=show) - for _ in range(self.frame_skip): - self.controller.run(eyehand_target) - self.grp_ctrl.run(signal=0) - self.step_mujoco_simulation() - rgb, depth = self.get_image_data("eyeinhand", depth=True, show=show) - self.observation["rgb"] = rgb - self.observation["depth"] = depth - - def move_eef(self, action): - success = False - target_pose = action.copy() + [0,0,1,1] - for _ in range(self.frame_skip): - self.controller.run(target_pose) - self.physics.data.qpos[:] = np.nan_to_num(self.physics.data.qpos, nan=0.0, posinf=0.0, neginf=0.0) - self.physics.data.ctrl[:] = np.nan_to_num(self.physics.data.ctrl, nan=0.0, posinf=1.0, neginf=-1.0) - self.step_mujoco_simulation() - ee_pos = self.get_ee_pos() - if max(np.abs(ee_pos - action)) < self.tolerance: - success = True - if success: - self.info["move"] = f"move to target {action}" - return success + def _sanitize_physics_data(self): + for attr in ['qpos','qvel','ctrl','qacc']: + setattr(self.physics.data, attr, np.nan_to_num(getattr(self.physics.data, attr), nan=0.0, posinf=0.0, neginf=0.0)) + + def get_ee_pos(self): + return self.physics.bind(self.eef_site).xpos.copy() + + def _set_action_space(self): + self.action_space = spaces.Box(low=-0.25, high=0.25, shape=[3], dtype=np.float32) - def down_and_grasp(self, action): - down_success = False - target_pose = action.copy() - target_pose[2] -= 0.05 - target_pose += [0,0,1,1] - for _ in range(self.frame_skip): + def _set_observation_space(self): + self.observation = defaultdict() + self.observation["rgb"] = 
np.zeros((self.IMAGE_WIDTH,self.IMAGE_HEIGHT,3), dtype=np.float32) + self.observation["depth"] = np.zeros((self.IMAGE_WIDTH,self.IMAGE_HEIGHT), dtype=np.float32) + + def move_eef(self, target): + if hasattr(target, "tolist"): + target = target.tolist() + target_pose = target + [0,0,1,1] + current_frame_skip = self.frame_skip if np.linalg.norm(np.array(self.get_ee_pos()) - np.array(target)) > 0.1 else 20 + for _ in range(current_frame_skip): self.controller.run(target_pose) + self._sanitize_physics_data() self.step_mujoco_simulation() - if max(np.abs(self.get_ee_pos() - action)) < self.tolerance: - down_success = True - if down_success: - for _ in range(self.frame_skip): - self.controller.run(target_pose) + if np.allclose(self.get_ee_pos(), target, atol=self.tolerance): + return True + return False + + def down_and_grasp(self, target): + down_pose = target.copy() + down_pose[2] -= 0.05 + success = self.move_eef(down_pose) + if success: + for _ in range(self.frame_skip // 2): self.grp_ctrl.run(signal=1) self.step_mujoco_simulation() - return down_success + return success def move_up_drop(self): - success = False up_pose = list(self.get_ee_pos()) up_pose[2] += 0.1 - up_pose += [0,0,1,1] drop_pose = self.drop_area + [0,0,1,1] - dist = np.linalg.norm(self.get_ee_pos() - self.get_body_com(self.target_objects[0])) - self.reward = -0.01 * dist - - for _ in range(self.frame_skip): - self.controller.run(up_pose) - self.step_mujoco_simulation() - - if self.check_grasp_success(): - self.info["grasp"] = "Success" + self.move_eef(up_pose) + grasp_success = self.check_grasp_success() + if grasp_success: self.grasped_num += 1 - self.reward = 1 - for _ in range(self.frame_skip): - self.controller.run(drop_pose) + self.move_eef(drop_pose) + for _ in range(self.frame_skip // 2): + self.grp_ctrl.run(signal=0) self.step_mujoco_simulation() - if max(np.abs(self.get_ee_pos() - self.drop_area)) < self.tolerance: - success = True - if success: - for _ in range(self.frame_skip): - 
self.controller.run(drop_pose) - self.grp_ctrl.run(signal=0) - self.step_mujoco_simulation() - return success - - def check_terminated(self): - for box in _target_box: - if self.get_body_com(box)[2] >= self.TABLE_HEIGHT: - return False - return True + return grasp_success def check_grasp_success(self): - right = self.get_body_com(_right_finger_name) - left = self.get_body_com(_left_finger_name) - dist = max(np.abs(right - left)) + dist = np.linalg.norm(self.get_body_com(_right_finger_name) - self.get_body_com(_left_finger_name)) return dist < 0.12 def open_gripper(self): - target_pose = list(self.get_ee_pos()) + [0,0,1,1] - for _ in range(self.frame_skip): - self.controller.run(target_pose) + for _ in range(self.frame_skip // 2): self.grp_ctrl.run(signal=0) self.step_mujoco_simulation() - def move_and_grasp(self, action): - self.open_gripper() - action[2] = 1.18 - - if self.move_eef(action): - if self.down_and_grasp(action): - self.move_up_drop() - else: - self.open_gripper() - - def _set_action_space(self): - self.action_space1 = spaces.Box(low=-0.25, high=0.25, shape=[2]) - - def _set_observation_space(self): - self.observation = defaultdict() - self.observation["rgb"] = np.zeros((self.IMAGE_WIDTH, self.IMAGE_HEIGHT, 3)) - self.observation["depth"] = np.zeros((self.IMAGE_WIDTH, self.IMAGE_HEIGHT)) + # 修复 pixel2world,避免 cam_mat 报错 + def pixel2world(self, cam_id, px, py, depth): + x = (px / self.IMAGE_WIDTH - 0.5) * 0.5 + y = (py / self.IMAGE_HEIGHT - 0.5) * 0.5 + z = depth + return np.array([x, y, z], dtype=np.float32) def step(self, action): - self.terminated = False self.info = {} - self.before_grasp(show=False) - self.move_and_grasp(action) - self.after_grasp(show=False) - - ee_pos = self.get_ee_pos() - obj_pos = self.get_body_com(self.target_objects[0]) - dist = np.linalg.norm(ee_pos - obj_pos) - reward = -0.02 * dist - - if self.info.get("grasp") == "Success": - reward += 10.0 - elif self.info.get("move") == "Failed": - reward -= 1.0 - - self.reward = reward 
- - if self.grasped_num == _grasp_target_num or self.grasp_step == 5: - self.terminated = True - if self.check_terminated(): - self.terminated = True - - self.grasp_step += 1 - return self.observation, self.reward, self.terminated, self.info + self.open_gripper() + moved = self.move_eef(action) + grasped = self.down_and_grasp(action) if moved else False + success = self.move_up_drop() if grasped else False - def reset_object(self): - for box_name in _target_box: - self.set_body_pos(box_name) + dist = np.linalg.norm(self.get_ee_pos() - self.get_body_com(self.target_objects[0])) + reward = 1.0 - np.tanh(3 * dist) + if success: + reward += 20 + self.info["grasp"] = "Success" + else: + self.info["grasp"] = "Failed" + if not moved: + reward -= 0.5 + + self.grasp_step += 1 + done = self.grasped_num == _grasp_target_num or self.grasp_step >= 10 + return self.observation, reward, done, self.info def reset(self): super().reset() - self.reset_object() self.grasped_num = 0 self.grasp_step = 0 - self.info["completion"] = "Failed" self.open_gripper() - self.before_grasp(show=False) return self.observation def reset_without_random(self): super().reset() self.grasped_num = 0 self.grasp_step = 0 - self.info["completion"] = "Failed" self.open_gripper() - self.before_grasp(show=False) - return self.observation - - def get_ee_pos(self): - return self.physics.bind(self.eef_site).xpos.copy() - - def pixel2world(self, cam_id, px, py, depth): - fx = fy = 500 - cx = self.IMAGE_WIDTH / 2 - cy = self.IMAGE_HEIGHT / 2 - x = (px - cx) * depth / fx - y = (py - cy) * depth / fy - return [x, y, depth] - - def world2pixel(self, cam_id, x, y, z): - fx = fy = 500 - cx = self.IMAGE_WIDTH / 2 - cy = self.IMAGE_HEIGHT / 2 - px = int((x * fx / z) + cx) - py = int((y * fy / z) + cy) - return px, py - - def set_body_pos(self, body_name, pos=None): - if pos is None: - pos = [random.uniform(-0.2, 0.2), random.uniform(-0.2, 0.2), self.TABLE_HEIGHT + 0.05] - body_id = self.physics.model.name2id(body_name, 
'body') - self.physics.model.body_pos[body_id] = pos \ No newline at end of file + return self.observation \ No newline at end of file diff --git a/src/grasprl/trainer/dqn_baseline.py b/src/grasprl/trainer/dqn_baseline.py new file mode 100644 index 0000000000..f8bb9f7553 --- /dev/null +++ b/src/grasprl/trainer/dqn_baseline.py @@ -0,0 +1,160 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import numpy as np +import random +import math +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +from envs.grasp import GraspRobot +from modules.ddpg import ReplayBuffer, Transition +from modules.qnet import MULTIDISCRETE_RESNET + +BATCH_SIZE = 16 +GAMMA = 0.95 +LR = 0.0005 + +class VisualFeatureEnhancer(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(4,16,3,padding=1) + self.conv2 = nn.Conv2d(16,4,3,padding=1) + self.relu = nn.ReLU() + self.pool = nn.MaxPool2d(2) + + def forward(self,x): + x = x.float() + x = self.relu(self.conv1(x)) + x = self.pool(x) + x = self.conv2(x) + return torch.clamp(x,-1,1) + +class DQN_Trainer: + def __init__(self, lr=LR, mem_size=10000, eps_start=0.9, eps_end=0.05, eps_decay=5000, + seed=42, log_dir="test", render_mode=None): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.env = GraspRobot(render_mode=render_mode) + self.memory = ReplayBuffer(mem_size) + self.writer = SummaryWriter(f"grasprl/log/DQN/{log_dir}") + self.eps_start,self.eps_end,self.eps_decay = eps_start, eps_end, eps_decay + self.steps_done = 0 + + self.q_net = MULTIDISCRETE_RESNET(1).to(self.device,dtype=torch.float32) + self.feat_enhance = VisualFeatureEnhancer().to(self.device,dtype=torch.float32) + self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr, eps=1e-4) + self.criterion = nn.SmoothL1Loss(reduction="mean").to(self.device) + + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + def transform_state(self,state): + depth = 
np.max(state["depth"]) - state["depth"] + depth = (depth - np.min(depth)) / (np.max(depth)-np.min(depth)+1e-8) + depth = depth.astype(np.float32) + + rgb = state["rgb"].astype(np.float32)/255.0 + tensor_rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0).to(self.device,dtype=torch.float32) + tensor_depth = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0).to(self.device,dtype=torch.float32) + + tensor_obs = torch.cat([tensor_rgb, tensor_depth], dim=1).float() + with torch.no_grad(): + tensor_obs = self.feat_enhance(tensor_obs) + return tensor_obs + + def _get_eps(self): + eps = self.eps_end + (self.eps_start - self.eps_end) * math.exp(-1.*self.steps_done/self.eps_decay) + self.steps_done += 1 + if self.steps_done % 100 == 0: + self.writer.add_scalar("Epsilon",eps,self.steps_done) + return eps + + def select_action(self,state): + eps = self._get_eps() + if random.random() > eps: + with torch.no_grad(): + q_vals = self.q_net(state).view(-1) # flatten H*W + max_idx = torch.argmax(q_vals) + return max_idx.unsqueeze(0).unsqueeze(0) # shape [1,1] + else: + valid_idx = list(range(self.env.IMAGE_WIDTH*self.env.IMAGE_HEIGHT)) + return torch.tensor([[random.choice(valid_idx)]],dtype=torch.long,device=self.device) + + def transform_action(self,max_idx,depth_before): + idx = max_idx.item() + px = idx % self.env.IMAGE_WIDTH + py = idx // self.env.IMAGE_WIDTH + px = np.clip(px,0,self.env.IMAGE_WIDTH-1) + py = np.clip(py,0,self.env.IMAGE_HEIGHT-1) + depth = depth_before[py][px] if depth_before[py][px]>0 else np.mean(depth_before) + return self.env.pixel2world(1,px,py,depth) + + def limit_action(self,action): + return np.clip(action, + [-0.25,-0.25,self.env.TABLE_HEIGHT+0.05], + [0.25,0.25,2.0]).astype(np.float32).tolist() + + def learn(self): + if len(self.memory) Date: Tue, 5 May 2026 17:14:44 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E5=9F=BA=E4=BA=8E=20DQN=20=E7=9A=84?= =?UTF-8?q?=E6=9C=BA=E6=A2=B0=E8=87=82=E8=A7=86=E8=A7=89=E6=8A=93=E5=8F=96?= 
=?UTF-8?q?=E5=BC=BA=E5=8C=96=E5=AD=A6=E4=B9=A0=E7=B3=BB=E7=BB=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/grasprl/grasprl/envs/grasp.py | 72 +++-- .../dataset/grasp_samples/check_dataset.py | 151 +++-------- .../grasp_samples/check_success_value.py | 54 ++-- .../dataset/grasp_samples/clean_grasp_data.py | 56 ++-- src/grasprl/trainer/dqn_baseline.py | 247 +++++++++--------- 5 files changed, 262 insertions(+), 318 deletions(-) diff --git a/src/grasprl/grasprl/envs/grasp.py b/src/grasprl/grasprl/envs/grasp.py index 50ef0dd407..08d812201b 100644 --- a/src/grasprl/grasprl/envs/grasp.py +++ b/src/grasprl/grasprl/envs/grasp.py @@ -11,15 +11,15 @@ _grasp_target_num = 6 class GraspRobot(MujocoPhyEnv): - def __init__(self, model_path="worlds/grasp.xml", frame_skip=50, render_mode=None): + def __init__(self, model_path="worlds/grasp.xml", frame_skip=40, render_mode=None): self.fullpath = model_path super().__init__(model_path, frame_skip=frame_skip) self.render_mode = render_mode - self.IMAGE_WIDTH, self.IMAGE_HEIGHT = 64,64 + self.IMAGE_WIDTH, self.IMAGE_HEIGHT = 64, 64 self._set_observation_space() self._set_action_space() self.tolerance = 0.005 - self.drop_area = [0.6,0.0,1.15] + self.drop_area = [0.6, 0.0, 1.15] self.TABLE_HEIGHT = 0.9 self.arm_joints_names = list(self.model_names.joint_names[:6]) @@ -41,24 +41,47 @@ def __init__(self, model_path="worlds/grasp.xml", frame_skip=50, render_mode=Non self.grasp_step = 0 def _sanitize_physics_data(self): - for attr in ['qpos','qvel','ctrl','qacc']: - setattr(self.physics.data, attr, np.nan_to_num(getattr(self.physics.data, attr), nan=0.0, posinf=0.0, neginf=0.0)) + for attr in ['qpos', 'qvel', 'ctrl', 'qacc']: + arr = getattr(self.physics.data, attr) + setattr(self.physics.data, attr, np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)) def get_ee_pos(self): return self.physics.bind(self.eef_site).xpos.copy() + def get_body_com(self, body_name): + body_id = 
self.physics.model.name2id(body_name, 'body') + return self.physics.data.xpos[body_id].copy() + + def set_body_pos(self, body_name, pos): + body_id = self.physics.model.name2id(body_name, 'body') + self.physics.model.body_pos[body_id] = pos + + def world2pixel(self, cam_id, x, y, z): + fx = fy = 500 + cx = self.IMAGE_WIDTH / 2 + cy = self.IMAGE_HEIGHT / 2 + px = int((x * fx / z) + cx) + py = int((y * fy / z) + cy) + return px, py + + def pixel2world(self, cam_id, px, py, depth): + x = (px / self.IMAGE_WIDTH - 0.5) * 0.48 + y = (py / self.IMAGE_HEIGHT - 0.5) * 0.48 + z = depth + return np.array([x, y, z], dtype=np.float32) + def _set_action_space(self): self.action_space = spaces.Box(low=-0.25, high=0.25, shape=[3], dtype=np.float32) def _set_observation_space(self): self.observation = defaultdict() - self.observation["rgb"] = np.zeros((self.IMAGE_WIDTH,self.IMAGE_HEIGHT,3), dtype=np.float32) - self.observation["depth"] = np.zeros((self.IMAGE_WIDTH,self.IMAGE_HEIGHT), dtype=np.float32) + self.observation["rgb"] = np.zeros((self.IMAGE_WIDTH, self.IMAGE_HEIGHT, 3), dtype=np.float32) + self.observation["depth"] = np.zeros((self.IMAGE_WIDTH, self.IMAGE_HEIGHT), dtype=np.float32) def move_eef(self, target): if hasattr(target, "tolist"): target = target.tolist() - target_pose = target + [0,0,1,1] + target_pose = target + [0, 0, 1, 1] current_frame_skip = self.frame_skip if np.linalg.norm(np.array(self.get_ee_pos()) - np.array(target)) > 0.1 else 20 for _ in range(current_frame_skip): self.controller.run(target_pose) @@ -70,7 +93,7 @@ def move_eef(self, target): def down_and_grasp(self, target): down_pose = target.copy() - down_pose[2] -= 0.05 + down_pose[2] -= 0.04 success = self.move_eef(down_pose) if success: for _ in range(self.frame_skip // 2): @@ -80,34 +103,27 @@ def down_and_grasp(self, target): def move_up_drop(self): up_pose = list(self.get_ee_pos()) - up_pose[2] += 0.1 - drop_pose = self.drop_area + [0,0,1,1] + up_pose[2] += 0.12 self.move_eef(up_pose) 
grasp_success = self.check_grasp_success() if grasp_success: self.grasped_num += 1 - self.move_eef(drop_pose) for _ in range(self.frame_skip // 2): self.grp_ctrl.run(signal=0) self.step_mujoco_simulation() return grasp_success def check_grasp_success(self): - dist = np.linalg.norm(self.get_body_com(_right_finger_name) - self.get_body_com(_left_finger_name)) - return dist < 0.12 + right = self.get_body_com(_right_finger_name) + left = self.get_body_com(_left_finger_name) + dist = np.linalg.norm(right - left) + return dist < 0.16 def open_gripper(self): for _ in range(self.frame_skip // 2): self.grp_ctrl.run(signal=0) self.step_mujoco_simulation() - # 修复 pixel2world,避免 cam_mat 报错 - def pixel2world(self, cam_id, px, py, depth): - x = (px / self.IMAGE_WIDTH - 0.5) * 0.5 - y = (py / self.IMAGE_HEIGHT - 0.5) * 0.5 - z = depth - return np.array([x, y, z], dtype=np.float32) - def step(self, action): self.info = {} self.open_gripper() @@ -115,18 +131,22 @@ def step(self, action): grasped = self.down_and_grasp(action) if moved else False success = self.move_up_drop() if grasped else False - dist = np.linalg.norm(self.get_ee_pos() - self.get_body_com(self.target_objects[0])) - reward = 1.0 - np.tanh(3 * dist) + ee_pos = self.get_ee_pos() + obj_pos = self.get_body_com(self.target_objects[0]) + dist = np.linalg.norm(ee_pos - obj_pos) + reward = 0.5 - dist * 2.0 + if success: - reward += 20 + reward += 15.0 self.info["grasp"] = "Success" else: self.info["grasp"] = "Failed" + if not moved: - reward -= 0.5 + reward -= 0.2 self.grasp_step += 1 - done = self.grasped_num == _grasp_target_num or self.grasp_step >= 10 + done = self.grasped_num == _grasp_target_num or self.grasp_step >= 18 return self.observation, reward, done, self.info def reset(self): diff --git a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_dataset.py b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_dataset.py index 5b55a2e655..c5bd8fa4bd 100644 --- 
a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_dataset.py +++ b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_dataset.py @@ -1,116 +1,43 @@ import os import numpy as np -# 改成相对路径 data_path = os.path.dirname(os.path.abspath(__file__)) -need_fields = ["grasp_success"] -report_name = "dataset_check.txt" - -def check_files(): - total_rgb = 0 - total_label = 0 - no_label = [] - no_rgb = [] - label_err = [] - lack_field = [] - - rgb_list = [] - label_list = [] - for file in os.listdir(data_path): - full_path = os.path.join(data_path, file) - if file[:4] == "rgb_" and file[-4:] == ".png": - rgb_list.append(file) - total_rgb += 1 - elif file[:6] == "label_" and file[-4:] == ".npy": - label_list.append(file) - total_label += 1 - - rgb_idx = set() - for f in rgb_list: - try: - idx = int(f.replace("rgb_", "").replace(".png", "")) - rgb_idx.add(idx) - except: - label_err.append(f"{f}:文件名不对,得是rgb_数字.png这种格式") - - label_idx = set() - for f in label_list: - try: - idx = int(f.replace("label_", "").replace(".npy", "")) - label_idx.add(idx) - except: - label_err.append(f"{f}:标签文件名不对,得是label_数字.npy") - - for f in rgb_list: - try: - idx = int(f.replace("rgb_", "").replace(".png", "")) - if idx not in label_idx: - no_label.append(f) - except: - pass - - for f in label_list: - try: - idx = int(f.replace("label_", "").replace(".npy", "")) - if idx not in rgb_idx: - no_rgb.append(f) - except: - pass - - for label_file in label_list: - label_full = os.path.join(data_path, label_file) - try: - label_data = np.load(label_full, allow_pickle=True).item() - if type(label_data) != dict: - label_err.append(f"{label_file}:不是字典格式,存的时候要转字典!") - continue - lack = [] - for field in need_fields: - if field not in label_data: - lack.append(field) - if lack: - lack_field.append(f"{label_file}:少字段{lack}") - except: - label_err.append(f"{label_file}:文件坏了或者加载失败") - - report = [] - report.append("数据集检查结果") - report.append("------------") - report.append(f"RGB文件总数:{total_rgb}") - 
report.append(f"Label文件总数:{total_label}\n") - - report.append(f"有RGB但没Label的文件({len(no_label)}个):") - for f in no_label[:10]: - report.append(f" - {f}") - if len(no_label) > 10: - report.append(f" - 还有{len(no_label)-10}个没显示") - report.append("") - - report.append(f"有Label但没RGB的文件({len(no_rgb)}个):") - for f in no_rgb[:10]: - report.append(f" - {f}") - if len(no_rgb) > 10: - report.append(f" - 还有{len(no_rgb)-10}个没显示") - report.append("") - - report.append(f"Label格式错误({len(label_err)}个):") - for f in label_err[:10]: - report.append(f" - {f}") - if len(label_err) > 10: - report.append(f" - 还有{len(label_err)-10}个没显示") - report.append("") - - report.append(f"Label缺字段({len(lack_field)}个):") - for f in lack_field[:10]: - report.append(f" - {f}") - if len(lack_field) > 10: - report.append(f" - 还有{len(lack_field)-10}个没显示") - report.append("") - - with open(os.path.join(data_path, report_name), "w", encoding="utf-8") as f: - f.write('\n'.join(report)) - print('\n'.join(report)) - print(f"\n报告已经存到{data_path}里的{report_name}了") - -if __name__ == "__main__": - check_files() \ No newline at end of file +need_fields = ["grasp_success", "grasp_force", "grasp_pose"] + +total = 0 +lack = 0 +error_files = [] +label_err = [] + +files = [] +for f in os.listdir(data_path): + if f.endswith(".npy"): + files.append(f) + +for f in files: + total += 1 + p = os.path.join(data_path, f) + try: + d = np.load(p, allow_pickle=True) + if isinstance(d, np.ndarray): + d = d.item() + for field in need_fields: + if field not in d: + lack += 1 + label_err.append(f"{f} 缺少 {field}") + if "grasp_force" in d: + v = d["grasp_force"] + if not (0 <= v <= 100): + label_err.append(f"{f} grasp_force 异常") + if "grasp_pose" in d: + pose = d["grasp_pose"] + if not isinstance(pose, list) or len(pose) != 3: + label_err.append(f"{f} grasp_pose 异常") + except: + error_files.append(f) + +print("检查完成") +print("总文件", total) +print("缺失字段数", lack) +print("错误文件", len(error_files)) +print("数据异常", len(label_err)) \ No newline at 
end of file diff --git a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_success_value.py b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_success_value.py index 07acce64ed..1333637868 100644 --- a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_success_value.py +++ b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/check_success_value.py @@ -1,36 +1,36 @@ import os import numpy as np -#改成相对路径 data_path = os.path.dirname(os.path.abspath(__file__)) -success_num = 0 -total_num = 0 -error_files = [] -label_files = [f for f in os.listdir(data_path) if f.startswith("label_") and f.endswith(".npy")] +success = 0 +fail = 0 +forces = [] +sforces = [] +fforces = [] -print("===== 各Label的grasp_success值 =====") -for f in label_files: - total_num += 1 +files = [f for f in os.listdir(data_path) if f.endswith(".npy")] + +for f in files: + p = os.path.join(data_path, f) try: - label = np.load(os.path.join(data_path, f), allow_pickle=True).item() - val = label.get("grasp_success", -1) - print(f"{f} → {val}") - if val == 1: - success_num += 1 + d = np.load(p, allow_pickle=True) + if isinstance(d, np.ndarray): + d = d.item() + s = d.get("grasp_success", 0) + force = d.get("grasp_force", 50) + forces.append(force) + if s == 1: + success += 1 + sforces.append(force) + else: + fail += 1 + fforces.append(force) except: - error_files.append(f) - print(f"{f} → 读取失败") - -print("\n===== 统计结果 =====") -print(f"总数量:{total_num}") -print(f"成功数:{success_num}") -print(f"失败数:{total_num - success_num - len(error_files)}") -print(f"读取失败:{len(error_files)}") -if total_num > 0: - print(f"成功率:{success_num/total_num*100:.2f}%") + continue -if error_files: - print("\n===== 失败文件 =====") - for f in error_files: - print(f"- {f}") \ No newline at end of file +print("抓取成功", success) +print("抓取失败", fail) +print("平均力", np.mean(forces) if forces else 0) +print("成功平均力", np.mean(sforces) if sforces else 0) +print("失败平均力", np.mean(fforces) if fforces else 0) \ No newline at end of file diff 
--git a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py index ec00a4cfd6..c69ea8ba54 100644 --- a/src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py +++ b/src/grasprl/grasprl/grasprl/dataset/grasp_samples/clean_grasp_data.py @@ -4,57 +4,53 @@ data_path = os.path.dirname(os.path.abspath(__file__)) clean_rules = { - "grasp_success": {"valid_range": [0,1], "default": 0}, + "grasp_success": {"valid_range": [0, 1], "default": 0}, "grasp_force": {"valid_range": [0, 100], "default": 50}, - "grasp_pose": {"valid_type": list, "default": [0,0,0]} + "grasp_pose": {"valid_type": list, "valid_length": 3, "default": [0, 0, 0]} } clean_count = 0 total_count = 0 error_files = [] -label_files = [f for f in os.listdir(data_path) if f.startswith("label_") and f.endswith(".npy")] +files = [f for f in os.listdir(data_path) if f.endswith(".npy")] -print("===== 机械抓取数据清洗 =====") -for f in label_files: +log = [] + +for f in files: total_count += 1 - label_path = os.path.join(data_path, f) - + p = os.path.join(data_path, f) try: - label = np.load(label_path, allow_pickle=True).item() + data = np.load(p, allow_pickle=True) + if isinstance(data, np.ndarray): + data = data.item() modified = False - for field, rule in clean_rules.items(): - if field not in label: - label[field] = rule["default"] + if field not in data: + data[field] = rule["default"] modified = True - print(f"{f} → 缺失 {field},已填充") continue - - val = label[field] - + val = data[field] if "valid_range" in rule: - vmin, vmax = rule["valid_range"] - if not (vmin <= val <= vmax): - label[field] = rule["default"] + minv, maxv = rule["valid_range"] + if not (minv <= val <= maxv): + data[field] = rule["default"] modified = True - print(f"{f} → {field} 异常,已修正") - if "valid_type" in rule: if not isinstance(val, rule["valid_type"]): - label[field] = rule["default"] + data[field] = rule["default"] + modified = True + if 
"valid_length" in rule: + if len(val) != rule["valid_length"]: + data[field] = rule["default"] modified = True - print(f"{f} → {field} 类型错误,已修正") - if modified: - np.save(label_path, label, allow_pickle=True) + np.save(p, data, allow_pickle=True) clean_count += 1 - except: error_files.append(f) - print(f"{f} → 读取失败") -print("\n清洗完成") -print(f"总样本:{total_count}") -print(f"已修正:{clean_count}") -print(f"异常文件:{len(error_files)}") \ No newline at end of file +print("清洗完成") +print("总样本", total_count) +print("已修复", clean_count) +print("错误文件", len(error_files)) \ No newline at end of file diff --git a/src/grasprl/trainer/dqn_baseline.py b/src/grasprl/trainer/dqn_baseline.py index f8bb9f7553..2491726282 100644 --- a/src/grasprl/trainer/dqn_baseline.py +++ b/src/grasprl/trainer/dqn_baseline.py @@ -4,157 +4,158 @@ import numpy as np import random import math -from torch.utils.tensorboard import SummaryWriter +import cv2 +import os from tqdm import tqdm +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from envs.grasp import GraspRobot from modules.ddpg import ReplayBuffer, Transition -from modules.qnet import MULTIDISCRETE_RESNET -BATCH_SIZE = 16 -GAMMA = 0.95 -LR = 0.0005 +IMAGE_SIZE = 32 +ACTION_SIZE = IMAGE_SIZE * IMAGE_SIZE +BATCH_SIZE = 64 +MEM_SIZE = 10000 +EPS_START = 1.0 +EPS_END = 0.05 +EPS_DECAY = 800 +GAMMA = 0.99 +LR = 0.001 +TARGET_UPDATE = 100 class VisualFeatureEnhancer(nn.Module): def __init__(self): super().__init__() - self.conv1 = nn.Conv2d(4,16,3,padding=1) - self.conv2 = nn.Conv2d(16,4,3,padding=1) + self.conv1 = nn.Conv2d(4, 16, 3, padding=1) + self.conv2 = nn.Conv2d(16, 4, 3, padding=1) self.relu = nn.ReLU() - self.pool = nn.MaxPool2d(2) - - def forward(self,x): - x = x.float() + def forward(self, x): x = self.relu(self.conv1(x)) - x = self.pool(x) x = self.conv2(x) - return torch.clamp(x,-1,1) + return x + +class DQN(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(4, 32, 3, 
class DQN(nn.Module):
    """Q-network mapping a 4-channel RGB-D observation to per-pixel grasp Q-values.

    Output is a flat (batch, ACTION_SIZE) tensor where each entry is the
    Q-value of grasping at one pixel of the IMAGE_SIZE x IMAGE_SIZE grid.
    """

    def __init__(self):
        super().__init__()
        # Three stride-1, padding-1 convolutions keep the spatial size fixed,
        # so the flattened feature map is exactly 128 * IMAGE_SIZE * IMAGE_SIZE.
        self.conv1 = nn.Conv2d(4, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc = nn.Linear(128 * IMAGE_SIZE * IMAGE_SIZE, ACTION_SIZE)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return a (batch, ACTION_SIZE) tensor of Q-values for input x."""
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))
        features = self.relu(self.conv3(features))
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
tensor_depth = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0).to(self.device,dtype=torch.float32) - - tensor_obs = torch.cat([tensor_rgb, tensor_depth], dim=1).float() + self.save_dir = os.path.join("grasprl", "dataset", "grasp_samples") + os.makedirs(self.save_dir, exist_ok=True) + self.enhancer = VisualFeatureEnhancer().to(self.device).eval() + + def transform_state(self, obs): + depth = obs["depth"].astype(np.float32) + depth = depth.max() - depth + if depth.max() > 0: + depth = depth / depth.max() + rgb = obs["rgb"].astype(np.float32) / 255.0 + rgb_t = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0).to(self.device) + depth_t = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0).to(self.device) + state = torch.cat([rgb_t, depth_t], dim=1).float() with torch.no_grad(): - tensor_obs = self.feat_enhance(tensor_obs) - return tensor_obs - - def _get_eps(self): - eps = self.eps_end + (self.eps_start - self.eps_end) * math.exp(-1.*self.steps_done/self.eps_decay) + state = self.enhancer(state) + return state + + def transform_action(self, action_idx, depth_img): + idx = action_idx.item() + px = idx % IMAGE_SIZE + py = idx // IMAGE_SIZE + px = max(0, min(px, IMAGE_SIZE-1)) + py = max(0, min(py, IMAGE_SIZE-1)) + depth_val = depth_img[py][px] if depth_img[py][px] > 0 else np.mean(depth_img) + action = self.env.pixel2world(1, px, py, depth_val) + action = np.clip(action, [-0.25, -0.25, 1.05], [0.25, 0.25, 1.3]) + return action.tolist() + + def select_action(self, state): + eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * self.steps_done / EPS_DECAY) self.steps_done += 1 - if self.steps_done % 100 == 0: - self.writer.add_scalar("Epsilon",eps,self.steps_done) - return eps - - def select_action(self,state): - eps = self._get_eps() if random.random() > eps: with torch.no_grad(): - q_vals = self.q_net(state).view(-1) # flatten H*W - max_idx = torch.argmax(q_vals) - return max_idx.unsqueeze(0).unsqueeze(0) # shape [1,1] + q_vals = self.policy_net(state) + action = 
q_vals.argmax().item() + return torch.tensor([[action]], device=self.device) else: - valid_idx = list(range(self.env.IMAGE_WIDTH*self.env.IMAGE_HEIGHT)) - return torch.tensor([[random.choice(valid_idx)]],dtype=torch.long,device=self.device) - - def transform_action(self,max_idx,depth_before): - idx = max_idx.item() - px = idx % self.env.IMAGE_WIDTH - py = idx // self.env.IMAGE_WIDTH - px = np.clip(px,0,self.env.IMAGE_WIDTH-1) - py = np.clip(py,0,self.env.IMAGE_HEIGHT-1) - depth = depth_before[py][px] if depth_before[py][px]>0 else np.mean(depth_before) - return self.env.pixel2world(1,px,py,depth) - - def limit_action(self,action): - return np.clip(action, - [-0.25,-0.25,self.env.TABLE_HEIGHT+0.05], - [0.25,0.25,2.0]).astype(np.float32).tolist() + return torch.tensor([[random.randint(0, ACTION_SIZE-1)]], device=self.device) def learn(self): - if len(self.memory)