From 25fc94bf26bd1c8f952af06ef2defe27fc18b4cf Mon Sep 17 00:00:00 2001 From: ZHOU Bin Date: Tue, 18 Jul 2023 00:43:41 +0800 Subject: [PATCH 1/6] add wandb --- yarr/utils/log_writer.py | 59 ++++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/yarr/utils/log_writer.py b/yarr/utils/log_writer.py index f9ccff1..6f2f85a 100644 --- a/yarr/utils/log_writer.py +++ b/yarr/utils/log_writer.py @@ -8,16 +8,19 @@ from yarr.agents.agent import ScalarSummary, HistogramSummary, ImageSummary, \ VideoSummary from torch.utils.tensorboard import SummaryWriter - +import wandb class LogWriter(object): def __init__(self, logdir: str, tensorboard_logging: bool, - csv_logging: bool): + csv_logging: bool, + wandb_logging: bool, + project_name: str = 'c2farm'): self._tensorboard_logging = tensorboard_logging self._csv_logging = csv_logging + self._wandb_logging = wandb_logging os.makedirs(logdir, exist_ok=True) if tensorboard_logging: self._tf_writer = SummaryWriter(logdir) @@ -25,36 +28,58 @@ def __init__(self, self._prev_row_data = self._row_data = OrderedDict() self._csv_file = os.path.join(logdir, 'data.csv') self._field_names = None - + if wandb_logging: + wandb.init( + project=project_name, + ) def add_scalar(self, i, name, value): if self._tensorboard_logging: self._tf_writer.add_scalar(name, value, i) if self._csv_logging: if len(self._row_data) == 0: self._row_data['step'] = i - self._row_data[name] = value.item() if isinstance( - value, torch.Tensor) else value + self._row_data[name] = value.item() if isinstance(value, torch.Tensor) else value + if self._wandb_logging: + wandb.log({name: value, 'step': i}) def add_summaries(self, i, summaries): for summary in summaries: try: - if isinstance(summary, ScalarSummary): - self.add_scalar(i, summary.name, summary.value) - elif self._tensorboard_logging: - if isinstance(summary, HistogramSummary): + if self._csv_logging and isinstance(summary, ScalarSummary): + self._row_data['step'] = i + name, value = summary.name, summary.value + self._row_data[name] = value.item() if isinstance(value, torch.Tensor) else value + + if isinstance(summary, HistogramSummary): + if self._tensorboard_logging: self._tf_writer.add_histogram( summary.name, summary.value, i) - elif isinstance(summary, ImageSummary): - # Only grab first item in batch - v = (summary.value if summary.value.ndim == 3 else - summary.value[0]) + if self._wandb_logging: + wandb.log({summary.name: wandb.Histogram(summary.value.cpu()), 'step': i}) + elif isinstance(summary, ImageSummary): + # Only grab first item in batch + + v = (summary.value if summary.value.ndim == 3 else + summary.value[0]) + if self._tensorboard_logging: self._tf_writer.add_image(summary.name, v, i) - elif isinstance(summary, VideoSummary): - # Only grab first item in batch - v = (summary.value if summary.value.ndim == 5 else - np.array([summary.value])) + if self._wandb_logging: + wandb.log({summary.name: wandb.Image(v), 'step': i}) + elif isinstance(summary, VideoSummary): + # Only grab first item in batch + v = (summary.value if summary.value.ndim == 5 else + np.array([summary.value])) + if self._tensorboard_logging: self._tf_writer.add_video( summary.name, v, i, fps=summary.fps) + if self._wandb_logging: + wandb.log({summary.name: wandb.Video(v, fps=summary.fps), 'step': i}) + elif isinstance(summary, ScalarSummary): + if self._tensorboard_logging: + self._tf_writer.add_scalar(summary.name, summary.value, i) + if self._wandb_logging: + wandb.log({summary.name: summary.value, 'step': i}) + except Exception as e: logging.error('Error on summary: %s' % summary.name) raise e From be64186ad169d9dcdfe3f4bbef2cc0bedf5e72b3 Mon Sep 17 00:00:00 2001 From: ZHOU Bin Date: Tue, 18 Jul 2023 00:44:54 +0800 Subject: [PATCH 2/6] add wandb --- yarr/runners/pytorch_train_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/yarr/runners/pytorch_train_runner.py b/yarr/runners/pytorch_train_runner.py index 65c58ea..4ff0334 100644 --- a/yarr/runners/pytorch_train_runner.py +++ b/yarr/runners/pytorch_train_runner.py @@ -43,6 +43,8 @@ def __init__(self, replay_ratio: Optional[float] = None, tensorboard_logging: bool = True, csv_logging: bool = False, + wandb_logging: bool = True, + project_name: str = "c2farm", buffers_per_batch: int = -1 # -1 = all ): super(PyTorchTrainRunner, self).__init__( @@ -78,7 +80,7 @@ def __init__(self, logging.info("'logdir' was None. No logging will take place.") else: self._writer = LogWriter( - self._logdir, tensorboard_logging, csv_logging) + self._logdir, tensorboard_logging, csv_logging, wandb_logging, project_name) if weightsdir is None: logging.info( "'weightsdir' was None. No weight saving will take place.") From 6b88ef33ff06b2094bdbded51c7c24daa2d2fa2e Mon Sep 17 00:00:00 2001 From: ZHOU Bin Date: Tue, 18 Jul 2023 19:03:11 +0800 Subject: [PATCH 3/6] add wandb cfg --- yarr/runners/pytorch_train_runner.py | 3 ++- yarr/utils/log_writer.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/yarr/runners/pytorch_train_runner.py b/yarr/runners/pytorch_train_runner.py index 4ff0334..1ba01bf 100644 --- a/yarr/runners/pytorch_train_runner.py +++ b/yarr/runners/pytorch_train_runner.py @@ -44,6 +44,7 @@ def __init__(self, tensorboard_logging: bool = True, csv_logging: bool = False, wandb_logging: bool = True, + wandb_cfg: dict = {}, project_name: str = "c2farm", buffers_per_batch: int = -1 # -1 = all ): @@ -80,7 +81,7 @@ def __init__(self, logging.info("'logdir' was None. No logging will take place.") else: self._writer = LogWriter( - self._logdir, tensorboard_logging, csv_logging, wandb_logging, project_name) + self._logdir, tensorboard_logging, csv_logging, wandb_logging, wandb_cfg, project_name) if weightsdir is None: logging.info( "'weightsdir' was None. No weight saving will take place.") diff --git a/yarr/utils/log_writer.py b/yarr/utils/log_writer.py index 6f2f85a..9d21893 100644 --- a/yarr/utils/log_writer.py +++ b/yarr/utils/log_writer.py @@ -17,6 +17,7 @@ def __init__(self, tensorboard_logging: bool, csv_logging: bool, wandb_logging: bool, + wandb_cfg: dict = None, project_name: str = 'c2farm'): self._tensorboard_logging = tensorboard_logging self._csv_logging = csv_logging @@ -31,6 +32,7 @@ def __init__(self, if wandb_logging: wandb.init( project=project_name, + config=wandb_cfg ) def add_scalar(self, i, name, value): if self._tensorboard_logging: From bcb9ebabd21774da9e1e576786d2044080964807 Mon Sep 17 00:00:00 2001 From: ZHOU Bin Date: Thu, 20 Jul 2023 13:43:27 +0800 Subject: [PATCH 4/6] fix typo --- yarr/envs/rlbench_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarr/envs/rlbench_env.py b/yarr/envs/rlbench_env.py index 495b6b6..9acbd93 100644 --- a/yarr/envs/rlbench_env.py +++ b/yarr/envs/rlbench_env.py @@ -84,7 +84,7 @@ def _get_cam_observation_elements(camera: CameraConfig, prefix: str, channels_la ObservationElement("%s_camera_intrinsics" % prefix, (3, 3), np.float32) ) if camera.depth: - shape = img_s + [1] if schannels_last else [1] + img_s + shape = img_s + [1] if channels_last else [1] + img_s elements.append(ObservationElement("%s_depth" % prefix, shape, np.float32)) if camera.mask: raise NotImplementedError() From 00c66451cc4d2f41db388932bdbe99398264dddf Mon Sep 17 00:00:00 2001 From: ZHOU Bin Date: Wed, 26 Jul 2023 09:53:52 +0800 Subject: [PATCH 5/6] fix cam --- yarr/envs/rlbench_env.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/yarr/envs/rlbench_env.py b/yarr/envs/rlbench_env.py index 9acbd93..af2096a 100644 --- a/yarr/envs/rlbench_env.py +++ b/yarr/envs/rlbench_env.py @@ -137,6 +137,11 @@ def _observation_elements( observation_config.wrist_camera, "wrist", channels_last ) ) + elements.extend( + _get_cam_observation_elements( + observation_config.overhead_camera, "overhead", channels_last + ) + ) return elements From c746de3f78cdfb361d4623060f9c6bc7db6f2113 Mon Sep 17 00:00:00 2001 From: ZHOU Bin Date: Fri, 28 Jul 2023 08:46:47 +0800 Subject: [PATCH 6/6] add exp name --- yarr/utils/log_writer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/yarr/utils/log_writer.py b/yarr/utils/log_writer.py index 9d21893..f083f2a 100644 --- a/yarr/utils/log_writer.py +++ b/yarr/utils/log_writer.py @@ -30,9 +30,16 @@ def __init__(self, self._csv_file = os.path.join(logdir, 'data.csv') self._field_names = None if wandb_logging: + try: + task_name = wandb_cfg['rlbench']['task'] + method_name = wandb_cfg['method']['name'] + exp_name = task_name + '-' + method_name + except: + exp_name = None wandb.init( project=project_name, - config=wandb_cfg + config=wandb_cfg, + name=exp_name ) def add_scalar(self, i, name, value): if self._tensorboard_logging: