Skip to content

Some predicted actions deviate noticeably from the reference actions when visualized #7

@oym1994

Description

@oym1994

5

Good Result

16

Good Result

12

Not good result

169

Not good result

Thanks for your great work.

I trained on the dataset "hand_wiping_1-14_5actiongap_10000points.hdf5" for about 3000 epochs with the provided config. Then I wrote some code to visualize the predicted actions against the reference actions (using the same dataset for validation). Most results look pretty good, but there are also some bad outputs, shown above, where the green curve is the reference trajectory and the red curve is the predicted action trajectory. I wonder whether this is expected: since the same dataset is used for both training and validation, I would expect the predicted output to be very close to the reference.

The visualization code is below; I hope it helps with validation before deploying on any real robot:

Code for saving the reference trajectory and predicted trajectory

import json
import torch
import robomimic.utils.obs_utils as ObsUtils
from robomimic.config import config_factory
import robomimic.utils.file_utils as FileUtils
import numpy as np
import os
import h5py
import robomimic.utils.train_utils as TrainUtils
from torch.utils.data import DataLoader


def vector_to_action_dict(
    action: np.ndarray, action_shapes: dict[str, int], action_keys: list[str]
) -> dict[str, np.ndarray]:
    """Split a flat action array into its named components.

    The last axis of ``action`` is treated as the concatenation of the
    components listed in ``action_keys`` (in that order). Any leading axes
    (e.g. batch or time) are preserved in every returned slice.

    Args:
        action: array whose last axis holds the concatenated action vector,
            e.g. shape ``(T, ac_dim)`` or ``(ac_dim,)``.
        action_shapes: per-key component shape, either an int or a shape
            tuple; the product is the number of columns the key occupies.
        action_keys: order in which the components appear along the last axis.

    Returns:
        Dict mapping each key in ``action_keys`` to its slice of ``action``
        along the last axis.

    Raises:
        ValueError: if the listed components need more columns than
            ``action`` has along its last axis.
    """
    total_dim = action.shape[-1]
    action_dict = dict()
    start_idx = 0
    for key in action_keys:
        # np.prod accepts both an int and a shape tuple; int() keeps slice bounds plain ints.
        this_act_dim = int(np.prod(action_shapes[key]))
        end_idx = start_idx + this_act_dim
        if end_idx > total_dim:
            # Without this check a too-narrow vector yields silently truncated/empty slices.
            raise ValueError(
                f"action has {total_dim} entries on its last axis, but keys "
                f"{action_keys} require at least {end_idx}"
            )
        # Slice the last axis so 1-D, 2-D and higher-rank inputs are all supported.
        action_dict[key] = action[..., start_idx:end_idx]
        start_idx = end_idx
    return action_dict


# load model
infer_device = "cuda:0"
checkpoint_path = "~/diffusion_policy_pcd_wiping_1-14/20240530170743/models/model_epoch_2300.pth"

algo_name, ckpt_dict = FileUtils.algo_name_from_checkpoint(ckpt_path=checkpoint_path)
# Number of denoising steps used at inference; set to None to keep the value stored in the checkpoint.
# NOTE(review): DDPM sampling with far fewer inference steps than it was trained with can degrade
# sample quality (DDIM is the scheduler designed for few-step sampling) — confirm this override
# is not contributing to the poor predictions.
dp_eval_steps = 10

if dp_eval_steps is not None:
    # HACK: modify the config, then dump to json again and write to ckpt_dict
    tmp_config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
    with tmp_config.values_unlocked():
        if tmp_config.algo.ddpm.enabled:
            tmp_config.algo.ddpm.num_inference_timesteps = dp_eval_steps
        elif tmp_config.algo.ddim.enabled:
            tmp_config.algo.ddim.num_inference_timesteps = dp_eval_steps
        else:
            raise Exception("should not reach here")
    # Must stay inside the `if`: tmp_config only exists when the config was patched,
    # so writing it back unconditionally raises NameError when dp_eval_steps is None.
    ckpt_dict['config'] = tmp_config.dump()

# restore policy from the (possibly patched) checkpoint dict
model, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_dict=ckpt_dict, device=infer_device, verbose=True)


config_file = "......../training_config/diffusion_policy_pcd_wiping_1-14.json"

# Load the training config JSON; the context manager closes the file handle.
with open(config_file, 'r') as config_fp:
    ext_cfg = json.load(config_fp)
config = config_factory(ext_cfg["algo_name"])
with config.values_unlocked():
    config.update(ext_cfg)

# read config to set up metadata for observation modalities (e.g. detecting rgb observations)
ObsUtils.initialize_obs_utils_with_config(config)
# make sure the dataset exists
eval_dataset_cfg = config.train.data[0]
dataset_path = os.path.expandvars(os.path.expanduser(eval_dataset_cfg["path"]))
ds_format = config.train.data_format
if not os.path.exists(dataset_path):
    raise Exception("Dataset at provided path {} not found!".format(dataset_path))
shape_meta = FileUtils.get_shape_metadata_from_dataset(
    dataset_path=dataset_path,
    action_keys=config.train.action_keys,
    all_obs_keys=config.all_obs_keys,
    ds_format=ds_format,
    verbose=True,
)
trainset, validset = TrainUtils.load_data_for_training(
    config, obs_keys=shape_meta["all_obs_keys"])
train_sampler = trainset.get_dataset_sampler()

# initialize data loaders
train_loader = DataLoader(
    dataset=trainset,
    sampler=train_sampler,
    # The inference loop evaluates only one sample per batch (it drops/indexes the batch
    # dimension), so load a single sample per batch instead of config.train.batch_size;
    # larger batches would also load many samples that are never evaluated.
    batch_size=1,
    # Deterministic order for evaluation; DataLoader also rejects shuffle=True together
    # with an explicit sampler.
    shuffle=False,
    num_workers=config.train.num_data_workers,
    drop_last=True,
)

# Run the policy on (up to) the first 200 training samples and store, per sample,
# the predicted and the reference eef_position trajectories for later plotting.
expected_traj = []
output_traj = []

num_steps = min(len(train_loader), 200)
# Number of leading observation frames fed to the policy — presumably the observation
# horizon used in training; TODO confirm against the config.
To = 3

# Layout of the flat action vector (loop-invariant, so defined once).
# Assumed to match config.train.action_keys — TODO confirm.
action_shapes = {"eef_position": 3, "eef_quaternion": 4, "gripper": 1}
action_keys = ["eef_position", "eef_quaternion", "gripper"]

with torch.no_grad():
    # zip with range() stops after num_steps batches without pulling an extra batch.
    for _, batch in zip(range(num_steps), train_loader):
        # Take the first sample explicitly. torch.squeeze(v, dim=0) is a no-op when the
        # batch holds more than one sample, and [:To] would then pick the first To
        # *samples* instead of the first To timesteps.
        obs = {k: v[0, :To] for k, v in batch['obs'].items()}
        expected_actions = batch['actions'][0].cpu().numpy()

        # Each sample is independent: reset the rollout policy so no state carried between
        # calls (e.g. a queued action sequence) leaks in from the previous sample.
        model.start_episode()
        output_action_numpy = model(obs)

        # Compare like with like: the eef_position component of both prediction and reference.
        # NOTE(review): the dataset may return normalized actions while the rollout policy
        # returns unnormalized ones — confirm eef_position normalization is identity,
        # otherwise unnormalize the reference before comparing.
        output_dict = vector_to_action_dict(output_action_numpy, action_shapes, action_keys)
        expected_dict = vector_to_action_dict(expected_actions, action_shapes, action_keys)
        output_traj.append(output_dict['eef_position'])
        expected_traj.append(expected_dict['eef_position'])

# h5py does not expand "~", so resolve the home directory explicitly.
save_file = os.path.expanduser("~/dexil_inference_with_trainset_debug-wiping-1-14.hdf5")

with h5py.File(save_file, 'w') as output_hdf5:
    for idx, (pred, ref) in enumerate(zip(output_traj, expected_traj)):
        traj = output_hdf5.create_group(f"traj_{idx}")
        traj.create_dataset("output_traj", data=pred)
        traj.create_dataset("expected_traj", data=ref)
    output_hdf5.attrs["num_samples"] = len(output_traj)

Code for visualizing the trajectories

import h5py
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import os

# h5py does not expand "~", so resolve the home directory explicitly.
traj_file = os.path.expanduser("~/dexil_inference_with_trainset_debug-wiping-1-14.hdf5")
save_dir = "..../debug_images/"
# savefig fails if the output directory does not exist.
os.makedirs(save_dir, exist_ok=True)

# One 3-D figure, reused (cleared) for every trajectory.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
print("start to plot trajectories!")

with h5py.File(traj_file, 'r') as f:
    traj_num = f.attrs['num_samples']
    # Iterate by index instead of f.items(): HDF5 lists groups in name order
    # (traj_0, traj_1, traj_10, traj_100, ...), which made the image number
    # differ from the trajectory number.
    for idx in range(traj_num):
        group = f[f"traj_{idx}"]
        expected_traj = group['expected_traj'][()]
        output_traj = group['output_traj'][()]

        # ax.clear() also removes the title and axis labels, so set them after every clear.
        ax.clear()
        ax.set_title("3D_Curve")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")

        # green = reference (expected) trajectory, red = predicted trajectory
        ax.plot(expected_traj[:, 0], expected_traj[:, 1], expected_traj[:, 2],
                color='green', label='reference')
        ax.plot(output_traj[:, 0], output_traj[:, 1], output_traj[:, 2],
                color='red', label='predicted')
        # Mark the first point of each trajectory.
        ax.plot([expected_traj[0, 0]], [expected_traj[0, 1]], [expected_traj[0, 2]],
                marker="o", markersize=4, markeredgecolor="green", markerfacecolor="green")
        ax.plot([output_traj[0, 0]], [output_traj[0, 1]], [output_traj[0, 2]],
                marker="o", markersize=4, markeredgecolor="red", markerfacecolor="red")
        ax.legend()
        plt.savefig(os.path.join(save_dir, f"{idx}.png"))

plt.show()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions