diff --git a/.gitignore b/.gitignore index 3dfc3ed..cf1e919 100644 --- a/.gitignore +++ b/.gitignore @@ -16,10 +16,5 @@ training/weights/*.pth training/config/dataset.yaml training/runs training/dataset/ - -#ファイル名で指定 -.codex -.vscode -__pycache__ - -#例外ファイルを指定 \ No newline at end of file +deployment/config/topomap/ +deployment/weights/*.pt diff --git a/OmniVLA b/OmniVLA index 5182600..eb32dea 160000 --- a/OmniVLA +++ b/OmniVLA @@ -1 +1 @@ -Subproject commit 5182600cb4a9ee07684e17cdd2a6cbafc56b8a68 +Subproject commit eb32dea11e3525cd440a2b65a4cc559220053841 diff --git a/deployment/config/nav.yaml b/deployment/config/nav.yaml index c1c8ccc..bd72e57 100644 --- a/deployment/config/nav.yaml +++ b/deployment/config/nav.yaml @@ -27,3 +27,11 @@ goal_pose: [10.0, -1.0, 0.0, -1.0] goal_image_path: OmniVLA/inference/goal_img.jpg lan_prompt: blue trash bin path_frame_id: base_link + +# Topological navigation using ml_planner-style topomap features. +use_toponav: false +topomap_path: config/topomap/topomap.yaml +topomap_image_dir: config/topomap/images +placenet_weight_path: deployment/weights/placenet.pt +toponav_crop_size: 288 +toponav_min_score: -1.0 diff --git a/deployment/navvla/navigation.py b/deployment/navvla/navigation.py index 4cb2388..923b4f9 100644 --- a/deployment/navvla/navigation.py +++ b/deployment/navvla/navigation.py @@ -4,6 +4,7 @@ import argparse import math +import sys from collections import deque from pathlib import Path from typing import Deque, Optional, Tuple @@ -15,11 +16,26 @@ import torch from PIL import Image as PILImage +_THIS_FILE = Path(__file__).resolve() +_REPO_ROOT_CANDIDATES = [ + _THIS_FILE.parents[2], + _THIS_FILE.parents[4] / "src" / "NavVLA" if len(_THIS_FILE.parents) > 4 else None, +] +for _repo_root in reversed([path for path in _REPO_ROOT_CANDIDATES if path is not None and (path / "OmniVLA").exists()]): + for _path in (_repo_root, _repo_root / "OmniVLA", _repo_root / "OmniVLA" / "inference"): + if str(_path) not in sys.path: + sys.path.insert(0, str(_path)) + +_INSTALLED_INFERENCE_DIR = _THIS_FILE.parents[1] / "OmniVLA" / "inference" +if _INSTALLED_INFERENCE_DIR.exists() and str(_INSTALLED_INFERENCE_DIR) not in sys.path: + sys.path.insert(0, str(_INSTALLED_INFERENCE_DIR)) + from OmniVLA.inference.utils_policy import ( load_model, transform_images_PIL_mask, ) -from .preprocess import build_mask, build_omnivla_edge_inputs, image_to_cv2, load_yaml +from .preprocess import build_mask, build_omnivla_edge_inputs, image_msg_to_bgr, image_to_cv2, load_yaml +from .toponav import TopologicalNavigator import rclpy from geometry_msgs.msg import PoseStamped, Twist @@ -41,6 +57,7 @@ def __init__( self.autonomous_flag = False self.context_queue = [] self.obs_image = None + self.obs_image_bgr = None self.package_share_dir = package_share_dir self.nav_cfg = load_yaml(nav_config_path) @@ -49,6 +66,7 @@ def __init__( self.init_params() self.init_model() self.init_model_modality() + self.init_toponav() self.image_sub = self.create_subscription(Image, "/image_raw", self.image_callback, 10) self.autonomous_sub = self.create_subscription(Bool, "/autonomous", self.autonomous_callback, 10) @@ -61,7 +79,9 @@ def __init__( def init_params(self) -> None: self.context_size = self.nav_cfg.get("context_size", 5) - self.waypoint_spacing = self.nav_cfg.get("metric_waypoint_spacing", 0.1) + self.metric_waypoint_spacing = self.nav_cfg.get("metric_waypoint_spacing", 0.1) + self.waypoint_spacing = self.nav_cfg.get("waypoint_spacing", 1) + self.action_scale = self.metric_waypoint_spacing * self.waypoint_spacing self.waypoint_select = self.nav_cfg.get("waypoint_select", 4) self.linear_max_vel = self.nav_cfg.get("linear_max_vel", 0.3) self.angular_max_vel = self.nav_cfg.get("angular_max_vel", 0.3) @@ -151,7 +171,46 @@ def init_model_modality(self) -> None: self.goal_image_tensor = transform_images_PIL_mask(goal_pil, self.mask_goal).to(self.device) self.goal_pose_tensor = torch.tensor([goal_pose], dtype=torch.float32, device=self.device) self.modality_tensor = torch.tensor([self.modality_id], dtype=torch.long, device=self.device) - + + def _update_text_feature(self) -> None: + prompt = self.latest_prompt if self.use_prompt else "No language instruction" + token = clip.tokenize(prompt, truncate=True).to(self.device) + with torch.no_grad(): + self.feat_text = self.text_encoder.encode_text(token) + + def resolve_package_path(self, raw_path: str) -> Path: + path = Path(raw_path) + return path if path.is_absolute() else self.package_share_dir / path + + def init_toponav(self) -> None: + self.use_toponav = bool(self.nav_cfg.get("use_toponav", False)) + self.toponav = None + self.toponav_current_index = None + self.toponav_goal_index = None + self.toponav_min_score = float(self.nav_cfg.get("toponav_min_score", -1.0)) + + if not self.use_toponav: + return + + if not self.use_goal_image: + raise ValueError("Toponav requires a modality_id that uses goal_image.") + + topomap_path = self.resolve_package_path(str(self.nav_cfg.get("topomap_path", "config/topomap/topomap.yaml"))) + image_dir = self.resolve_package_path(str(self.nav_cfg.get("topomap_image_dir", "config/topomap/images"))) + weight_path = self.resolve_package_path(str(self.nav_cfg.get("placenet_weight_path", "deployment/weights/placenet.pt"))) + + self.toponav = TopologicalNavigator( + topomap_path=topomap_path, + image_dir=image_dir, + weight_path=weight_path, + device=self.device, + image_size=self.goal_size, + crop_size=int(self.nav_cfg.get("toponav_crop_size", 288)), + delta=float(self.nav_cfg.get("toponav_delta", 5.0)), + window_lower=int(self.nav_cfg.get("toponav_window_lower", -1)), + window_upper=int(self.nav_cfg.get("toponav_window_upper", 10)), + ) + self.get_logger().info(f"Toponav loaded: nodes={len(self.toponav.nodes)}, topomap={topomap_path}") def autonomous_callback(self, msg: Bool) -> None: self.autonomous_flag = bool(msg.data) @@ -162,6 +221,8 @@ def prompt_callback(self, msg: String) -> None: self._update_text_feature() def image_callback(self, msg: Image) -> None: + self.obs_image_bgr = image_msg_to_bgr(msg) + cv_image = image_to_cv2(msg, self.clip_size) self.obs_image = PILImage.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)) @@ -176,6 +237,8 @@ def timer_callback(self) -> None: if len(self.context_queue) < self.context_size + 1: return + self.update_toponav_goal() + obs_images, map_images, cur_large_img = build_omnivla_edge_inputs( context_queue=self.context_queue, current_image=self.obs_image, @@ -208,6 +271,31 @@ def timer_callback(self) -> None: self.publisher_path(waypoints) self.publisher_command_velocity(linear_vel, angular_vel) + def update_toponav_goal(self) -> None: + if self.toponav is None or self.obs_image_bgr is None: + return + + current_index, score = self.toponav.estimate_current_node(self.obs_image_bgr) + if score < self.toponav_min_score: + self.get_logger().warn( + f"Toponav score below threshold: score={score:.3f}, threshold={self.toponav_min_score:.3f}" + ) + return + + goal_index = self.toponav.select_goal_node(current_index) + if current_index == self.toponav_current_index and goal_index == self.toponav_goal_index: + return + + goal_pil = self.toponav.load_goal_image(goal_index) + self.goal_image_tensor = transform_images_PIL_mask(goal_pil, self.mask_goal).to(self.device) + self.toponav_current_index = current_index + self.toponav_goal_index = goal_index + current_node = self.toponav.nodes[current_index] + goal_node = self.toponav.nodes[goal_index] + self.get_logger().info( + "Toponav goal updated: " + f"current_id={current_node.node_id}, goal_id={goal_node.node_id}, score={score:.3f}" + ) def publisher_path(self, waypoints: np.ndarray) -> None: msg = NavPath() @@ -217,8 +305,8 @@ def publisher_path(self, waypoints: np.ndarray) -> None: for wp in waypoints: pose = PoseStamped() pose.header = msg.header - x = float(wp[0]) * self.waypoint_spacing - y = float(wp[1]) * self.waypoint_spacing + x = float(wp[0]) * self.action_scale + y = float(wp[1]) * self.action_scale yaw = math.atan2(float(wp[3]), float(wp[2])) pose.pose.position.x = x @@ -241,8 +329,8 @@ def action_to_waypoints_and_cmd_vel(self, action_pred: np.ndarray) -> Tuple[np.n selected = max(0, min(self.waypoint_select, waypoints.shape[0] - 1)) dx, dy, hx, hy = [float(v) for v in waypoints[selected]] - dx *= self.waypoint_spacing - dy *= self.waypoint_spacing + dx *= self.action_scale + dy *= self.action_scale eps = 1e-8 dt = 1.0 / 3.0 @@ -257,9 +345,6 @@ def action_to_waypoints_and_cmd_vel(self, action_pred: np.ndarray) -> Tuple[np.n linear_vel = dx / dt angular_vel = math.atan(dy / dx) / dt - linear_vel = float(np.clip(linear_vel, 0.0, 0.5)) - angular_vel = float(np.clip(angular_vel, -1.0, 1.0)) - maxv = float(self.linear_max_vel) maxw = float(self.angular_max_vel) if abs(linear_vel) <= maxv: diff --git a/deployment/navvla/preprocess.py b/deployment/navvla/preprocess.py index d7753bf..5f60672 100644 --- a/deployment/navvla/preprocess.py +++ b/deployment/navvla/preprocess.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +import sys from typing import TYPE_CHECKING from typing import Dict, List, Tuple @@ -10,6 +11,16 @@ import yaml from PIL import Image as PILImage +_THIS_FILE = Path(__file__).resolve() +_REPO_ROOT_CANDIDATES = [ + _THIS_FILE.parents[2], + _THIS_FILE.parents[4] / "src" / "NavVLA" if len(_THIS_FILE.parents) > 4 else None, +] +for _repo_root in reversed([path for path in _REPO_ROOT_CANDIDATES if path is not None and (path / "OmniVLA").exists()]): + for _path in (_repo_root, _repo_root / "OmniVLA", _repo_root / "OmniVLA" / "inference"): + if str(_path) not in sys.path: + sys.path.insert(0, str(_path)) + from OmniVLA.inference.utils_policy import transform_images_PIL_mask, transform_images_map if TYPE_CHECKING: @@ -45,9 +56,44 @@ def build_mask(size: Tuple[int, int], use_mask: bool, mask_path: str) -> np.ndar return loaded.astype(np.float32) +def image_msg_to_bgr(msg: "Image") -> np.ndarray: + encoding = msg.encoding.lower() + channels_by_encoding = { + "bgr8": 3, + "rgb8": 3, + "bgra8": 4, + "rgba8": 4, + "mono8": 1, + "8uc1": 1, + "8uc3": 3, + "8uc4": 4, + "yuv422_yuy2": 2, + "yuyv": 2, + "yuy2": 2, + } + if encoding not in channels_by_encoding: + raise ValueError(f"Unsupported image encoding: {msg.encoding}") + + channels = channels_by_encoding[encoding] + row = np.frombuffer(msg.data, dtype=np.uint8).reshape(int(msg.height), int(msg.step)) + image = row[:, : int(msg.width) * channels].reshape(int(msg.height), int(msg.width), channels) + + if encoding in ("bgr8", "8uc3"): + return image.copy() + if encoding == "rgb8": + return cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + if encoding == "bgra8": + return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR) + if encoding == "rgba8": + return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR) + if encoding in ("yuv422_yuy2", "yuyv", "yuy2"): + return cv2.cvtColor(image, cv2.COLOR_YUV2BGR_YUY2) + + return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + + def image_to_cv2(msg: "Image", output_size: Tuple[int, int]) -> np.ndarray: - frame = np.frombuffer(msg.data, dtype=np.uint8).reshape((int(msg.height), int(msg.width), 3)) - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + frame = image_msg_to_bgr(msg) side = min(frame.shape[0], frame.shape[1]) offset_y = (frame.shape[0] - side) // 2 diff --git a/deployment/navvla/toponav.py b/deployment/navvla/toponav.py new file mode 100644 index 0000000..dc2b49c --- /dev/null +++ b/deployment/navvla/toponav.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import torch +import yaml +from PIL import Image as PILImage +from torchvision import transforms + + +@dataclass(frozen=True) +class TopologicalNode: + node_id: int + image_name: str + feature: np.ndarray + edges: Tuple[Dict[str, object], ...] + + +class TopologicalNavigator: + def __init__( + self, + topomap_path: Path, + image_dir: Path, + weight_path: Path, + device: torch.device, + image_size: Tuple[int, int], + crop_size: int = 288, + delta: float = 5.0, + window_lower: int = -1, + window_upper: int = 10, + ) -> None: + self.topomap_path = Path(topomap_path) + self.image_dir = Path(image_dir) + self.weight_path = Path(weight_path) + self.device = device + self.image_size = (int(image_size[0]), int(image_size[1])) + self.crop_size = int(crop_size) + + # Bayesian filter パラメータ + self.delta = float(delta) + self.window_lower = int(window_lower) + self.window_upper = int(window_upper) + # 遷移モデル:ウィンドウ内の移動を等確率とする一様分布 + self.transition = np.ones(self.window_upper - self.window_lower, dtype=np.float32) + + # 信念・lambdaは最初のフレームで初期化 + self.belief: Optional[np.ndarray] = None + self.lambda1: float = 0.0 + + self.nodes = self._load_topomap(self.topomap_path) + self.node_index_by_id = {node.node_id: idx for idx, node in enumerate(self.nodes)} + self.feature_matrix = np.stack([node.feature for node in self.nodes], axis=0) + self.model = torch.jit.load(str(self.weight_path), map_location=self.device).eval() + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize((85, 85), antialias=True), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + @staticmethod + def _load_topomap(topomap_path: Path) -> List[TopologicalNode]: + if not topomap_path.exists(): + raise FileNotFoundError(f"Topomap file not found: {topomap_path}") + + with topomap_path.open("r", encoding="utf-8") as f: + topomap = yaml.safe_load(f) or {} + + raw_nodes = topomap.get("nodes", []) + if not raw_nodes: + raise ValueError(f"No nodes found in topomap: {topomap_path}") + + nodes = [] + for raw_node in raw_nodes: + feature = np.asarray(raw_node["feature"], dtype=np.float32).reshape(-1) + norm = float(np.linalg.norm(feature)) + if norm <= 1e-8: + raise ValueError(f"Topomap feature has zero norm: node_id={raw_node.get('id')}") + feature = feature / norm + + edges = tuple(raw_node.get("edges", [])) + if not edges: + raise ValueError(f"Topomap node must have at least one edge: node_id={raw_node.get('id')}") + + nodes.append( + TopologicalNode( + node_id=int(raw_node["id"]), + image_name=str(raw_node["image"]), + feature=feature, + edges=edges, + ) + ) + return nodes + + def _center_crop(self, image_bgr: np.ndarray) -> np.ndarray: + height, width = image_bgr.shape[:2] + side = min(height, width, self.crop_size) + top = (height - side) // 2 + left = (width - side) // 2 + return image_bgr[top : top + side, left : left + side] + + def extract_feature(self, image_bgr: np.ndarray) -> np.ndarray: + cropped = self._center_crop(image_bgr) + image_rgb = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB) + image_tensor = self.transform(image_rgb).unsqueeze(0).to(self.device, dtype=torch.float32) + with torch.no_grad(): + feature = self.model(image_tensor) + + feature_np = feature.squeeze(0).detach().cpu().numpy().reshape(-1).astype(np.float32) + norm = float(np.linalg.norm(feature_np)) + if norm <= 1e-8: + raise RuntimeError("PlaceNet returned a zero-norm feature.") + return feature_np / norm + + def _compute_distances(self, query_feature: np.ndarray) -> np.ndarray: + # L2正規化済みのためdot積=コサイン類似度 → コサイン距離に変換 + dots = np.clip(np.dot(self.feature_matrix, query_feature), -1.0, 1.0) + return np.sqrt(2.0 - 2.0 * dots) + + def _observation_likelihood(self, query_feature: np.ndarray) -> np.ndarray: + return np.exp(-self.lambda1 * self._compute_distances(query_feature)) + + def _initialize_belief(self, query_feature: np.ndarray) -> None: + dists = self._compute_distances(query_feature) + descriptor_quantiles = np.quantile(dists, [0.025, 0.975]) + self.lambda1 = np.log(self.delta) / (descriptor_quantiles[1] - descriptor_quantiles[0]) + self.belief = np.exp(-self.lambda1 * dists) + self.belief /= self.belief.sum() + + def _update_belief(self, query_feature: np.ndarray) -> None: + # ===== Prediction ステップ:遷移モデルで信念を前進方向に伝播 ===== + if self.window_lower < 0: + conv_ind_l = abs(self.window_lower) + conv_ind_h = len(self.belief) + abs(self.window_lower) + bel_ind_l, bel_ind_h = 0, len(self.belief) + else: + conv_ind_l, conv_ind_h = 0, len(self.belief) - self.window_lower + bel_ind_l, bel_ind_h = self.window_lower, len(self.belief) + + belief_pad = np.pad(self.belief, len(self.transition) - 1, mode="symmetric") + conv = np.convolve(belief_pad, self.transition, mode="valid") + self.belief[bel_ind_l:bel_ind_h] = conv[conv_ind_l:conv_ind_h] + + if self.window_lower > 0: + self.belief[: self.window_lower] = 0.0 + + # ===== Measurement ステップ:観測尤度でベイズ更新 ===== + self.belief *= self._observation_likelihood(query_feature) + self.belief /= self.belief.sum() + + def estimate_current_node(self, image_bgr: np.ndarray) -> Tuple[int, float]: + query_feature = self.extract_feature(image_bgr) + + if self.belief is None: + self._initialize_belief(query_feature) + else: + self._update_belief(query_feature) + + best_index = int(np.argmax(self.belief)) + return best_index, float(self.belief[best_index]) + + def reset(self) -> None: + """信念を初期化する。環境が大きく変わった場合や再スタート時に呼ぶ。""" + self.belief = None + self.lambda1 = 0.0 + + def select_goal_node(self, current_index: int) -> int: + current_node = self.nodes[current_index] + target_id = int(current_node.edges[0].get("target", current_node.node_id)) + return self.node_index_by_id.get(target_id, current_index) + + def load_goal_image(self, node_index: int) -> PILImage.Image: + image_path = self.image_dir / self.nodes[node_index].image_name + if not image_path.exists(): + raise FileNotFoundError(f"Topomap image not found: {image_path}") + return PILImage.open(image_path).convert("RGB").resize(self.image_size) diff --git a/deployment/scripts/create_topomap.py b/deployment/scripts/create_topomap.py new file mode 100644 index 0000000..8404d42 --- /dev/null +++ b/deployment/scripts/create_topomap.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Optional, Sequence + +import cv2 +import numpy as np +import torch +import yaml +from torchvision import transforms + + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +class TopomapGenerator: + def __init__( + self, + dataset_path: Path, + output_dir: Path, + weight_path: Path, + device: torch.device, + saved_step: int = 10, + crop_size: int = 288, + ) -> None: + self.dataset_path = Path(dataset_path) + self.output_dir = Path(output_dir) + self.output_image_dir = self.output_dir / "images" + self.topomap_path = self.output_dir / "topomap.yaml" + self.weight_path = Path(weight_path) + self.device = device + self.saved_step = int(saved_step) + self.crop_size = int(crop_size) + if self.saved_step < 1: + raise ValueError("saved_step must be >= 1") + + self.model = torch.jit.load(str(self.weight_path), map_location=self.device).eval() + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize((85, 85), antialias=True), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + @staticmethod + def _image_sort_key(path: Path) -> tuple[int, str]: + try: + return int(path.stem), path.name + except ValueError: + return 0, path.name + + def _load_trajectory_names(self) -> list[str]: + traj_names_path = self.dataset_path / "traj_names.txt" + if not traj_names_path.exists(): + return [path.name for path in sorted(self.dataset_path.glob("traj_*")) if path.is_dir()] + + with traj_names_path.open("r", encoding="utf-8") as f: + return [line.strip() for line in f if line.strip()] + + def _load_trajectory(self, traj_name: str) -> list[Path]: + traj_dir = self.dataset_path / traj_name + if not traj_dir.is_dir(): + raise FileNotFoundError(f"Trajectory directory not found: {traj_dir}") + + image_paths = sorted( + list(traj_dir.glob("*.jpg")) + list(traj_dir.glob("*.png")), + key=self._image_sort_key, + ) + if not image_paths: + raise ValueError(f"No images found in trajectory: {traj_dir}") + + return image_paths + + def _preprocess_image(self, image_path: Path) -> np.ndarray: + image = cv2.imread(str(image_path), cv2.IMREAD_COLOR) + if image is None: + raise ValueError(f"Failed to read image: {image_path}") + + height, width = image.shape[:2] + side = min(height, width, self.crop_size) + top = (height - side) // 2 + left = (width - side) // 2 + return image[top : top + side, left : left + side] + + def _extract_feature(self, image_bgr: np.ndarray) -> Sequence[float]: + image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + image_tensor = self.transform(image_rgb).unsqueeze(0).to(self.device, dtype=torch.float32) + with torch.no_grad(): + feature = self.model(image_tensor) + + feature_np = feature.squeeze(0).detach().cpu().numpy().reshape(-1).astype(np.float32) + norm = float(np.linalg.norm(feature_np)) + if norm > 1e-8: + feature_np = feature_np / norm + return feature_np.tolist() + + def generate(self) -> Path: + self.output_image_dir.mkdir(parents=True, exist_ok=True) + traj_names = self._load_trajectory_names() + if not traj_names: + raise ValueError(f"No trajectories found in dataset: {self.dataset_path}") + + nodes = [] + for traj_name in traj_names: + image_paths = self._load_trajectory(traj_name) + for image_path in image_paths[:: self.saved_step]: + node_index = len(nodes) + frame_index = self._image_sort_key(image_path)[0] + + cropped_image = self._preprocess_image(image_path) + + # 保存用は cv2.resize で 85x85 に縮小 + save_image = cv2.resize(cropped_image, (85, 85), interpolation=cv2.INTER_AREA) + output_image_name = f"img{node_index + 1:05d}.png" + cv2.imwrite(str(self.output_image_dir / output_image_name), save_image) + + node = { + "id": node_index, + "image": output_image_name, + "feature": self._extract_feature(cropped_image), + "source": { + "trajectory": traj_name, + "frame": frame_index, + "image": str(image_path.relative_to(self.dataset_path)), + }, + } + nodes.append(node) + + for node_index, node in enumerate(nodes): + target = node_index + 1 if node_index + 1 < len(nodes) else node_index + node["edges"] = [{"target": target}] + + with self.topomap_path.open("w", encoding="utf-8") as f: + yaml.safe_dump({"nodes": nodes}, f, sort_keys=False, allow_unicode=False) + return self.topomap_path + + +def resolve_cli_path(raw_path: str, base_path: Path = Path.cwd()) -> Path: + path = Path(raw_path).expanduser() + return path if path.is_absolute() else base_path / path + + +def main(args: Optional[list[str]] = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument("dataset_path", help="Dataset directory generated by NavVLA data_collection.py") + parser.add_argument( + "--output-dir", + default="deployment/config/topomap", + help="Directory where topomap.yaml and images/ are written, relative to the NavVLA repository root", + ) + parser.add_argument( + "--weights", + default="deployment/weights/placenet.pt", + help="PlaceNet TorchScript weight path, relative to the NavVLA repository root", + ) + parser.add_argument("--saved-step", type=int, default=10, help="Use every Nth dataset image as a node") + parser.add_argument("--crop-size", type=int, default=288, help="Center crop size before resizing to 85x85") + parsed = parser.parse_args(args) + + dataset_path = resolve_cli_path(parsed.dataset_path) + output_dir = resolve_cli_path(parsed.output_dir, REPO_ROOT) + weights_path = resolve_cli_path(parsed.weights, REPO_ROOT) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + generator = TopomapGenerator( + dataset_path=dataset_path, + output_dir=output_dir, + weight_path=weights_path, + device=device, + saved_step=parsed.saved_step, + crop_size=parsed.crop_size, + ) + topomap_path = generator.generate() + print(f"Topomap saved: {topomap_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/deployment/scripts/data_collection.py b/deployment/scripts/data_collection.py new file mode 100644 index 0000000..31bda34 --- /dev/null +++ b/deployment/scripts/data_collection.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 + +import math +import pickle +import time +from pathlib import Path +from typing import Optional + +import cv2 +import numpy as np +import rclpy +from cv_bridge import CvBridge +from nav_msgs.msg import Odometry +from rclpy.node import Node +from rclpy.qos import qos_profile_sensor_data +from sensor_msgs.msg import Image +from std_msgs.msg import Empty +import shutil + +SAMPLE_INTERVAL = 0.1 # 10 Hz + +# data_collection/data_collection.py -> parents[1] = NavVLA/ +_DEFAULT_SAVE_DIR = str(Path(__file__).resolve().parents[1] / 'training' / 'dataset') + + +class DataCollectionNode(Node): + """ + 単眼カメラ(USBカメラ)とMid-360(FAST-LIO)のオドメトリを記録するノード。 + /flag トピック(ジョイスティックAボタン等)で開始/停止を切り替える。 + 収集データは NavVLA の EdgeNavigationDataset が読める形式で保存される。 + """ + + def __init__(self) -> None: + super().__init__('data_collection_node') + self.bridge = CvBridge() + + self.declare_parameter('save_dir', _DEFAULT_SAVE_DIR) + self.save_dir = Path( + self.get_parameter('save_dir').get_parameter_value().string_value + ) + + # Training dataset parameters — must match training/config/network.yaml and dataset.yaml. + # Used to compute the minimum trajectory length that yields at least one trainable sample. + # Formula (from training/data/dataset.py): + # min_frames = (context_size + len_traj_pred) * waypoint_spacing + end_slack + 1 + self.declare_parameter('context_size', 5) + self.declare_parameter('len_traj_pred', 8) + self.declare_parameter('waypoint_spacing', 1) + self.declare_parameter('end_slack', 3) + + context_size = self.get_parameter('context_size').get_parameter_value().integer_value + len_traj_pred = self.get_parameter('len_traj_pred').get_parameter_value().integer_value + waypoint_spacing = self.get_parameter('waypoint_spacing').get_parameter_value().integer_value + end_slack = self.get_parameter('end_slack').get_parameter_value().integer_value + self.min_frames = (context_size + len_traj_pred) * waypoint_spacing + end_slack + 1 + + self.raw_data_buffer = [] + self.frame_count = 0 + self.temp_dir = None + self.chunk_size = 1000 + self.dataset_dir: Optional[Path] = None + self.traj_count = 0 + self.traj_names = [] + + self.latest_image: Optional[np.ndarray] = None + self.latest_pose = np.array([0.0, 0.0, 0.0], dtype=np.float32) + + self.is_recording = False + self._odom_received = False + self._img_received = False + + self.create_subscription( + Image, + '/image_raw', + self._image_callback, + qos_profile_sensor_data, + ) + self.create_subscription( + Odometry, + '/Odometry', + self._odom_callback, + qos_profile_sensor_data, + ) + self.create_subscription( + Empty, + '/flag', + self._flag_callback, + 10, + ) + + self.create_timer(SAMPLE_INTERVAL, self._timer_callback) + + self.get_logger().info('==========================================') + self.get_logger().info('Data Collection Node Started') + self.get_logger().info('Waiting for /image_raw and /Odometry...') + self.get_logger().info('Publish to /flag to START/STOP recording') + self.get_logger().info(f'Save dir: {self.save_dir}') + self.get_logger().info(f'Min frames required: {self.min_frames}') + self.get_logger().info('==========================================') + + def _image_callback(self, msg: Image) -> None: + try: + cv_img = self.bridge.imgmsg_to_cv2(msg, 'bgr8') + + h, w = cv_img.shape[:2] + crop_size = min(h, w) + x_start = (w - crop_size) // 2 + y_start = (h - crop_size) // 2 + cropped_img = cv_img[y_start:y_start + crop_size, x_start:x_start + crop_size] + self.latest_image = cv2.resize(cropped_img, (224, 224), interpolation=cv2.INTER_LINEAR) + if not self._img_received: + self._img_received = True + self.get_logger().info('Camera image received.') + except Exception as e: + self.get_logger().error(f'Failed to convert image: {e}') + + + def _odom_callback(self, msg: Odometry) -> None: + pos = msg.pose.pose.position + q = msg.pose.pose.orientation + self.latest_pose = np.array([pos.x, pos.y, q.w, q.x, q.y, q.z], dtype=np.float32) + if not self._odom_received: + self._odom_received = True + self.get_logger().info('Odometry received.') + + + def _flag_callback(self, _msg: Empty) -> None: + self.is_recording = not self.is_recording + status = 'STARTED' if self.is_recording else 'PAUSED' + + if self.is_recording: + self.raw_data_buffer = [] + if self.dataset_dir is None: + timestamp_str = time.strftime('%Y%m%d_%H%M%S') + dataset_name = f'navvla_{timestamp_str}' + self.dataset_dir = self.save_dir / dataset_name + self.dataset_dir.mkdir(parents=True, exist_ok=True) + self.frame_count = 0 + self.traj_count = 0 + self.traj_names = [] + self.temp_dir = Path('/tmp') / f'navvla_recording_{timestamp_str}_{self.traj_count}' + self.temp_dir.mkdir(parents=True, exist_ok=True) + self.get_logger().info(f'🟢 recording {status}. Temp dir: {self.temp_dir}') + else: + self._reset_temp_dir() + self.get_logger().info( + f'🟢 recording RESUMED. Dataset: {self.dataset_dir}. ' + f'Temp dir: {self.temp_dir}' + ) + else: + self._flush_buffer(final=True) + self.get_logger().info(f' ⏸️ Recording {status}. Buffer size: {len(self.raw_data_buffer)}') + + + def _timer_callback(self) -> None: + if not self.is_recording: + return + if self.latest_image is None or not self._odom_received: + return + + x, y, qw, qx, qy, qz = self.latest_pose + yaw = math.atan2( + 2.0 * (qw * qz + qx * qy), + 1.0 - 2.0 * (qy * qy + qz * qz), + ) + + pose_with_yaw = np.array([x, y, yaw], dtype=np.float32) + chunk_frame_id = len(self.raw_data_buffer) + + self.raw_data_buffer.append({ + 'frame_id' : chunk_frame_id, + 'pose': pose_with_yaw.copy(), # 12 bytes + }) + + if self.temp_dir is not None: + ok = cv2.imwrite( + str(self.temp_dir / f'{chunk_frame_id:06d}.jpg'), + self.latest_image + ) + if not ok: + self.raw_data_buffer.pop() # pose と画像の対応を保つ + self.get_logger().error( + f'❌Failed to write image {chunk_frame_id:06d}.jpg — frame dropped.' + ) + return + + self.frame_count += 1 + + if len(self.raw_data_buffer) >= self.chunk_size: + self._flush_buffer(final=False) + + if self.frame_count % 10 == 0: + self.get_logger().info( + f'📊 Collected {self.frame_count} frames. ' + f'Buffer: {len(self.raw_data_buffer)} / {self.chunk_size}' + ) + + def _reset_temp_dir(self) -> None: + if self.temp_dir is not None and self.temp_dir.exists(): + try: + shutil.rmtree(self.temp_dir) + except Exception as e: + self.get_logger().error(f'❌Failed to clean temp dir: {e}') + timestamp_str = time.strftime('%Y%m%d_%H%M%S') + self.temp_dir = Path('/tmp') / f'navvla_recording_{timestamp_str}_{self.traj_count}' + self.temp_dir.mkdir(parents=True, exist_ok=True) + + def _flush_buffer(self, final: bool) -> None: + num_samples = len(self.raw_data_buffer) + if num_samples == 0: + return + + if final and num_samples < self.min_frames: + self.get_logger().warn( + f'⚠️Not enough data ({num_samples} frames). Need at least {self.min_frames}.' + ) + self.raw_data_buffer = [] + if self.temp_dir is not None and self.temp_dir.exists(): + try: + shutil.rmtree(self.temp_dir) + except Exception as e: + self.get_logger().error(f'❌Failed to clean temp dir: {e}') + self.temp_dir = None + return + + if not final and num_samples < self.chunk_size: + return + + if self.dataset_dir is None: + timestamp_str = time.strftime('%Y%m%d_%H%M%S') + self.dataset_dir = self.save_dir / f'navvla_{timestamp_str}' + self.dataset_dir.mkdir(parents=True, exist_ok=True) + + traj_name = f'traj_{self.traj_count}' + traj_dir = self.dataset_dir / traj_name + + try: + traj_dir.mkdir(parents=True, exist_ok=True) + except Exception as e: + self.get_logger().error(f'❌Failed to create directory: {e}') + return + + positions = [] + yaws = [] + + for out_index, metadata in enumerate(self.raw_data_buffer): + pose = metadata['pose'] + positions.append([pose[0], pose[1]]) + yaws.append(pose[2]) + if self.temp_dir is not None: + src_path = self.temp_dir / f"{metadata['frame_id']:06d}.jpg" + dst_path = traj_dir / f'{out_index}.jpg' + if src_path.exists(): + shutil.move(str(src_path), str(dst_path)) + else: + self.get_logger().error(f'❌Missing temp image: {src_path}') + return + + traj_data = { + 'position': np.array(positions, dtype=np.float32), + 'yaw': np.array(yaws, dtype=np.float32), + } + + try: + with open(traj_dir / 'traj_data.pkl', 'wb') as f: + pickle.dump(traj_data, f) + + self.traj_names.append(traj_name) + with open(self.dataset_dir / 'traj_names.txt', 'w') as f: + f.write(''.join(f'{name}\n' for name in self.traj_names)) + + self.get_logger().info(f'💾Saved {num_samples} frames to: {traj_dir}') + except Exception as e: + self.get_logger().error(f'❌Failed to save data: {e}') + return + + self.traj_count += 1 + self.raw_data_buffer = [] + if final: + if self.temp_dir is not None and self.temp_dir.exists(): + try: + shutil.rmtree(self.temp_dir) + except Exception as e: + self.get_logger().error(f'❌Failed to clean temp dir: {e}') + self.temp_dir = None + else: + self._reset_temp_dir() + + def save_data(self) -> None: + if self.frame_count == 0 and not self.traj_names: + self.get_logger().warn('⚠️No data collected. Nothing to save.') + return + + self._flush_buffer(final=True) + + if self.dataset_dir is not None: + self.get_logger().info(f'Saved dataset to: {self.dataset_dir}') + + if self.temp_dir is not None and self.temp_dir.exists(): + try: + shutil.rmtree(self.temp_dir) + self.get_logger().info(f'🧹 Cleaned up temp dir: {self.temp_dir}') + except Exception as e: + self.get_logger().error(f'❌Failed to clean temp dir: {e}') + +def main(args=None): + rclpy.init(args=args) + node = DataCollectionNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + node.get_logger().info('Finishing data collection...') + finally: + node.save_data() + node.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/deployment/setup.py b/deployment/setup.py index 554f4fe..39cb633 100644 --- a/deployment/setup.py +++ b/deployment/setup.py @@ -5,33 +5,29 @@ package_name = "navvla" deployment_dir = Path(__file__).resolve().parent -omnivla_dir = (deployment_dir.parent / "OmniVLA").resolve() deployment_weight_files = [ str(path.relative_to(deployment_dir)) for path in sorted((deployment_dir / "weights").glob("*")) if path.is_file() ] +topomap_files = [ + str(path.relative_to(deployment_dir)) + for path in sorted((deployment_dir / "config" / "topomap").glob("*.yaml")) + if path.is_file() +] +topomap_image_files = [ + str(path.relative_to(deployment_dir)) + for path in sorted((deployment_dir / "config" / "topomap" / "images").glob("*.png")) + if path.is_file() +] navvla_packages = find_packages(include=["navvla", "navvla.*"]) -omnivla_packages = ["OmniVLA", "OmniVLA.inference"] + find_packages( - where=str(omnivla_dir), - include=["prismatic", "prismatic.*"], -) setup( name=package_name, version="0.1.0", - packages=navvla_packages + omnivla_packages, - package_dir={ - "OmniVLA": str(omnivla_dir), - "OmniVLA.inference": str(omnivla_dir / "inference"), - "prismatic": str(omnivla_dir / "prismatic"), - }, - package_data={ - "OmniVLA.inference": ["*.jpg"], - "prismatic": ["py.typed", "vla/datasets/data_config.yaml"], - }, + packages=navvla_packages, data_files=[ ( "share/ament_index/resource_index/packages", @@ -54,6 +50,14 @@ f"share/{package_name}/deployment/weights", deployment_weight_files, ), + ( + f"share/{package_name}/config/topomap", + topomap_files, + ), + ( + f"share/{package_name}/config/topomap/images", + topomap_image_files, + ), ], install_requires=["setuptools", "numpy", "PyYAML"], zip_safe=True, @@ -61,10 +65,10 @@ maintainer_email="s21c1135sc@s.chibakoudai.jp", description="NavVLA wrappers for OmniVLA training and ROS2 deployment.", license="MIT", - tests_require=["pytest"], entry_points={ "console_scripts": [ "navigation_node = navvla.navigation:main", + "create_topomap = navvla.create_topomap:main", ], }, ) diff --git a/train.py b/train.py index 1d9c00a..80e3f9f 100644 --- a/train.py +++ b/train.py @@ -40,32 +40,57 @@ def create_dataloaders( missing = [key for key in required_data_keys if key not in data_config] if missing: raise ValueError(f"Missing required keys for dataset {dataset_name}: {missing}") - for data_split_type in ("train", "test"): - if data_split_type not in data_config: - continue - dataset = EdgeNavigationDataset( - data_folder=data_config["data_folder"], - data_split_folder=data_config[data_split_type], - dataset_name=str(dataset_name), - image_size=tuple(dataset_cfg["image_size"]), - waypoint_spacing=int(data_config["waypoint_spacing"]), - len_traj_pred=int(network_cfg["len_traj_pred"]), - learn_angle=bool(network_cfg["learn_angle"]), - context_size=int(network_cfg["context_size"]), - context_type=str(dataset_cfg["context_type"]), - end_slack=int(data_config["end_slack"]), - goals_per_obs=int(data_config["goals_per_obs"]), - normalize=bool(dataset_cfg["normalize"]), - modality_id=int(data_config["modality_id"]), - metric_waypoint_spacing=float(data_config.get("metric_waypoint_spacing", 1.0)), - clip_image_size=tuple(dataset_cfg.get("clip_image_size", (224, 224))), - clip_model=str(dataset_cfg.get("clip_model", "ViT-B/32")), + train_ratio = data_config.get("train_ratio") + common_kwargs = dict( + data_folder=data_config["data_folder"], + dataset_name=str(dataset_name), + image_size=tuple(dataset_cfg["image_size"]), + waypoint_spacing=int(data_config["waypoint_spacing"]), + len_traj_pred=int(network_cfg["len_traj_pred"]), + learn_angle=bool(network_cfg["learn_angle"]), + context_size=int(network_cfg["context_size"]), + context_type=str(dataset_cfg["context_type"]), + end_slack=int(data_config["end_slack"]), + goals_per_obs=int(data_config["goals_per_obs"]), + normalize=bool(dataset_cfg["normalize"]), + modality_id=int(data_config["modality_id"]), + metric_waypoint_spacing=float(data_config.get("metric_waypoint_spacing", 1.0)), + clip_image_size=tuple(dataset_cfg.get("clip_image_size", (224, 224))), + clip_model=str(dataset_cfg.get("clip_model", "ViT-B/32")), + ) + + if train_ratio is not None: + traj_names_path = Path(data_config["data_folder"]) / "traj_names.txt" + if not traj_names_path.exists(): + raise FileNotFoundError(f"traj_names.txt not found in data_folder: {traj_names_path}") + train_dataset = EdgeNavigationDataset( + data_split_folder=traj_names_path, + train_ratio=float(train_ratio), + split_type="train", + **common_kwargs, + ) + test_dataset = EdgeNavigationDataset( + data_split_folder=traj_names_path, + train_ratio=float(train_ratio), + split_type="test", + **common_kwargs, ) - if data_split_type == "train": - train_datasets.append(dataset) - train_eval_datasets[f"{dataset_name}_{data_split_type}"] = dataset - else: - test_datasets[f"{dataset_name}_{data_split_type}"] = dataset + train_datasets.append(train_dataset) + train_eval_datasets[f"{dataset_name}_train"] = train_dataset + test_datasets[f"{dataset_name}_test"] = test_dataset + else: + for data_split_type in ("train", "test"): + if data_split_type not in data_config: + continue + dataset = EdgeNavigationDataset( + data_split_folder=data_config[data_split_type], + **common_kwargs, + ) + if data_split_type == "train": + train_datasets.append(dataset) + train_eval_datasets[f"{dataset_name}_{data_split_type}"] = dataset + else: + test_datasets[f"{dataset_name}_{data_split_type}"] = dataset if not train_datasets: raise ValueError("No train datasets were configured in dataset.yaml.") diff --git a/training/config/dataset.yaml b/training/config/dataset.yaml index 9c29129..43f59ad 100644 --- a/training/config/dataset.yaml +++ b/training/config/dataset.yaml @@ -4,10 +4,9 @@ normalize: true clip_model: ViT-B/32 datasets: - sample: - data_folder: /home//nomad_dataset/sample - train: /home//data_splits/sample/train/ # path to train folder with traj_names.txt - test: /home//data_splits/sample/test/ # path to test folder with traj_names.txt + navvla_test: + data_folder: /home/orne/kasai_ws/src/NavVLA/training/dataset/navvla_20260515_111925 + train_ratio: 0.7 end_slack: 3 # because many trajectories end in collisions goals_per_obs: 1 # how many goals are sampled per observation waypoint_spacing: 1 diff --git a/training/config/train.yaml b/training/config/train.yaml index 98bc9d8..1771516 100644 --- a/training/config/train.yaml +++ b/training/config/train.yaml @@ -1,14 +1,14 @@ weights_path: training/weights/omnivla-edge.pth run_root_dir: training/runs -epochs: 10 -batch_size: 4 -learning_rate: 1.0e-4 -weight_decay: 0.0 -num_workers: 4 +epochs: 40 +batch_size: 28 +learning_rate: 5.0e-5 +weight_decay: 1.0e-2 +num_workers: 8 seed: 42 -save_freq: 1 eval_freq: 1 max_train_steps: null max_test_steps: null +resume_from: null diff --git a/training/data/dataset.py b/training/data/dataset.py index c66cc8f..21ebbe6 100644 --- a/training/data/dataset.py +++ b/training/data/dataset.py @@ -26,6 +26,7 @@ "feat_text", "current_img", "actions", + "dist_to_goal", ) KEY_ALIASES = { @@ -112,6 +113,8 @@ def __init__( metric_waypoint_spacing: float = 1.0, clip_image_size: Tuple[int, int] = (224, 224), clip_model: str = DEFAULT_CLIP_MODEL, + train_ratio: float | None = None, + split_type: str = "train", ) -> None: self.data_folder = Path(data_folder) self.data_split_folder = Path(data_split_folder) @@ -130,6 +133,9 @@ def __init__( self.clip_image_size = tuple(int(v) for v in clip_image_size) self.clip_model = str(clip_model) self.dummy_text_feature: torch.Tensor | None = None + self.action_horizon = self.len_traj_pred * self.waypoint_spacing + self.train_ratio = train_ratio + self.split_type = split_type self.text_encoder = None self.prompt_cache: Dict[str, List[str]] = {} self.text_feature_cache: Dict[str, torch.Tensor] = {} @@ -213,10 +219,25 @@ def build_sample_index(self) -> List[Tuple[str, int, int]]: traj_data = self.load_trajectory(traj_name) traj_len = len(self.read_positions(traj_data)) end_time = traj_len - self.end_slack - action_horizon - for curr_time in range(begin_time, end_time): + + all_times = [ + t for t in range(begin_time, end_time) + if min(traj_len - self.end_slack - 1, t + action_horizon) > t + ] + + if self.train_ratio is not None: + # トラジェクトリ名をシードにすることでtrain/test両データセットで同じシャッフル順になる + rng = np.random.default_rng(seed=abs(hash(traj_name)) % (2**31)) + shuffled = rng.permutation(len(all_times)) + n_train = int(len(shuffled) * self.train_ratio) + indices = shuffled[:n_train] if self.split_type == "train" else shuffled[n_train:] + selected_times = [all_times[i] for i in indices] + else: + selected_times = all_times + + for curr_time in selected_times: max_goal_time = min(traj_len - self.end_slack - 1, curr_time + action_horizon) - if max_goal_time > curr_time: - samples.append((traj_name, curr_time, max_goal_time)) + samples.append((traj_name, curr_time, max_goal_time)) return samples def __len__(self) -> int: @@ -392,6 +413,9 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: map_images = self.build_map_images(obs_images, goal_image) feat_text = self.build_language_feature(traj_name, traj_data, curr_time) + dist_to_goal = torch.as_tensor( + (goal_time - curr_time) / self.action_horizon, dtype=torch.float32 + ) return { "obs_images": obs_images, "goal_pose": goal_pose, @@ -401,6 +425,7 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: "feat_text": feat_text, "current_img": current_img, "actions": actions, + "dist_to_goal": dist_to_goal, } diff --git a/training/eval.py b/training/eval.py index 485d5cd..2e5cba2 100644 --- a/training/eval.py +++ b/training/eval.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F from torch.utils.data import DataLoader +from tqdm import tqdm class Test: @@ -27,15 +28,19 @@ def run(self, max_steps: Optional[int] = None) -> Dict[str, float]: self.model.eval() total_loss = 0.0 + total_action_loss = 0.0 + total_dist_loss = 0.0 total_batches = 0 with torch.no_grad(): - for step, raw_batch in enumerate(self.loader, start=1): + total_steps = len(self.loader) if max_steps is None else min(len(self.loader), max_steps) + progress = tqdm(self.loader, total=total_steps, desc="eval", leave=False) + for step, raw_batch in enumerate(progress, start=1): if max_steps is not None and step > max_steps: break batch = {key: value.to(self.device) for key, value in raw_batch.items()} - action_pred, _, _ = self.model( + action_pred, dist_pred, _ = self.model( batch["obs_images"], batch["goal_pose"].float(), batch["map_images"], @@ -44,11 +49,20 @@ def run(self, max_steps: Optional[int] = None) -> Dict[str, float]: batch["feat_text"].float(), batch["current_img"], ) - loss = F.l1_loss(action_pred, batch["actions"].float()) + action_loss = F.l1_loss(action_pred, batch["actions"].float()) + dist_loss = F.l1_loss(dist_pred.squeeze(-1), batch["dist_to_goal"].float()) + loss = action_loss + dist_loss total_loss += float(loss.detach().cpu()) + total_action_loss += float(action_loss.detach().cpu()) + total_dist_loss += float(dist_loss.detach().cpu()) total_batches += 1 + progress.set_postfix(loss=f"{float(loss.detach().cpu()):.4f}") if total_batches == 0: raise RuntimeError("Test loader produced no batches.") - return {"loss": total_loss / total_batches} + return { + "loss": total_loss / total_batches, + "action_loss": total_action_loss / total_batches, + "dist_loss": total_dist_loss / total_batches, + } diff --git a/training/loop.py b/training/loop.py index f403df0..dcb0a7f 100644 --- a/training/loop.py +++ b/training/loop.py @@ -2,6 +2,7 @@ from __future__ import annotations +from datetime import datetime import random from pathlib import Path from typing import Dict @@ -16,6 +17,59 @@ from training.train import Train +def resolve_path(navvla_root: Path, raw_path: object) -> Path: + path = Path(str(raw_path)) + return path if path.is_absolute() else (navvla_root / path).resolve() + + +def make_timestamped_run_dir(run_root_dir: Path) -> Path: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + run_dir = run_root_dir / timestamp + suffix = 1 + while run_dir.exists(): + suffix += 1 + run_dir = run_root_dir / f"{timestamp}_{suffix:02d}" + return run_dir + + +def extract_state_dict(checkpoint: object) -> Dict[str, torch.Tensor]: + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + return checkpoint["state_dict"] + if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: + return checkpoint["model_state_dict"] + if isinstance(checkpoint, dict): + return checkpoint + raise ValueError("Checkpoint must be a state_dict or a dict containing state_dict/model_state_dict.") + + +def save_checkpoint( + checkpoint_path: Path, + model_path: Path, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + epoch: int, + train_cfg: Dict[str, object], + network_cfg: Dict[str, object], + dataset_cfg: Dict[str, object], +) -> None: + checkpoint_path.parent.mkdir(parents=True, exist_ok=True) + state_dict = model.state_dict() + torch.save( + { + "epoch": epoch, + "state_dict": state_dict, + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "train_cfg": train_cfg, + "network_cfg": network_cfg, + "dataset_cfg": dataset_cfg, + }, + checkpoint_path, + ) + torch.save(state_dict, model_path) + + def validate_config( train_cfg: Dict[str, object], network_cfg: Dict[str, object], @@ -34,7 +88,6 @@ def validate_config( "weight_decay", "num_workers", "seed", - "save_freq", "eval_freq", ), ), @@ -80,13 +133,14 @@ def main_loop( test_dataloaders: Dict[str, DataLoader], ) -> int: validate_config(train_cfg, network_cfg, dataset_cfg) - weights_path = Path(str(train_cfg["weights_path"])) - weights_path = (weights_path if weights_path.is_absolute() else (navvla_root / weights_path).resolve()) + weights_path = resolve_path(navvla_root, train_cfg["weights_path"]) - run_dir = Path(str(train_cfg["run_root_dir"])) - run_dir = run_dir if run_dir.is_absolute() else (navvla_root / run_dir).resolve() + run_root_dir = Path(str(train_cfg["run_root_dir"])) + run_root_dir = run_root_dir if run_root_dir.is_absolute() else (navvla_root / run_root_dir).resolve() + run_dir = make_timestamped_run_dir(run_root_dir) print(f"[NavVLA] OmniVLA-edge weights: {weights_path}") + print(f"[NavVLA] Run root directory: {run_root_dir}") print(f"[NavVLA] Run directory: {run_dir}") seed = int(train_cfg["seed"]) @@ -113,7 +167,7 @@ def main_loop( raise FileNotFoundError(f"OmniVLA-edge weights not found: {weights_path}") checkpoint = torch.load(weights_path, map_location="cpu") - state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint + state_dict = extract_state_dict(checkpoint) model.load_state_dict(state_dict, strict=True) model = model.to(device) @@ -123,6 +177,25 @@ def main_loop( weight_decay=float(train_cfg["weight_decay"]), ) + total_epochs = int(train_cfg["epochs"]) + # base_lrs を fresh の lr=learning_rate に固定するため、resume 前に構築 + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs) + + start_epoch = 1 + resume_from = train_cfg.get("resume_from") + if resume_from: + resume_path = resolve_path(navvla_root, resume_from) + if not resume_path.exists(): + raise FileNotFoundError(f"Resume checkpoint not found: {resume_path}") + resume_checkpoint = torch.load(resume_path, map_location="cpu") + model.load_state_dict(extract_state_dict(resume_checkpoint), strict=True) + if isinstance(resume_checkpoint, dict) and "optimizer_state_dict" in resume_checkpoint: + optimizer.load_state_dict(resume_checkpoint["optimizer_state_dict"]) + if isinstance(resume_checkpoint, dict) and "epoch" in resume_checkpoint: + start_epoch = int(resume_checkpoint["epoch"]) + 1 + scheduler.load_state_dict(resume_checkpoint["scheduler_state_dict"]) + print(f"[NavVLA] Resumed from {resume_path} at epoch={start_epoch - 1}") + Trainer = Train(model=model, loader=train_loader, optimizer=optimizer, device=device) TrainEvaluators = { dataset_type: Test(model=model, loader=loader, device=device) @@ -135,20 +208,31 @@ def main_loop( max_train_steps = train_cfg.get("max_train_steps") max_test_steps = train_cfg.get("max_test_steps") - save_freq = int(train_cfg["save_freq"]) eval_freq = int(train_cfg["eval_freq"]) + best_loss = float("inf") if SummaryWriter is None: raise ImportError("TensorBoard is not available. Install it with: pip install tensorboard") tensorboard_dir = train_cfg.get("tensorboard_log_dir", run_dir / "tensorboard") writer = SummaryWriter(log_dir=str(tensorboard_dir)) - print(f"[NavVLA] TensorBoard: {tensorboard_dir}") + print(f"[NavVLA] TensorBoard: tensorboard --logdir {tensorboard_dir}") + + if start_epoch > total_epochs: + print(f"[NavVLA] Nothing to train: start_epoch={start_epoch} > epochs={total_epochs}") + writer.close() + return 0 - for epoch in range(1, int(train_cfg["epochs"]) + 1): + for epoch in range(start_epoch, total_epochs + 1): train_metrics = Trainer.run( max_steps=None if max_train_steps is None else int(max_train_steps) ) - print(f"[NavVLA] epoch={epoch} train={train_metrics}") + scheduler.step() + current_lr = scheduler.get_last_lr()[0] + print(f"[NavVLA] epoch={epoch} train={train_metrics} lr={current_lr:.2e}") writer.add_scalar("loss/train_total", train_metrics["loss"], epoch) + writer.add_scalar("loss/train_action", train_metrics["action_loss"], epoch) + writer.add_scalar("loss/train_dist", train_metrics["dist_loss"], epoch) + writer.add_scalar("lr", current_lr, epoch) + writer.flush() train_eval_losses = [] for dataset_type, evaluator in TrainEvaluators.items(): @@ -158,6 +242,7 @@ def main_loop( writer.add_scalar(f"loss/train/{dataset_type}", metrics["loss"], epoch) if train_eval_losses: writer.add_scalar("loss/train_datasets_total", float(np.mean(train_eval_losses)), epoch) + writer.flush() if epoch % eval_freq == 0: eval_losses = [] @@ -170,15 +255,41 @@ def main_loop( writer.add_scalar(f"loss/eval/{dataset_type}", test_metrics["loss"], epoch) if eval_losses: writer.add_scalar("loss/eval_total", float(np.mean(eval_losses)), epoch) + writer.flush() - if epoch % save_freq == 0: - run_dir.mkdir(parents=True, exist_ok=True) - checkpoint_path = run_dir / "model_latest.pth" - torch.save(model.state_dict(), checkpoint_path) - print(f"[NavVLA] saved={checkpoint_path}") + if eval_losses: + candidate_loss = float(np.mean(eval_losses)) + elif train_eval_losses: + candidate_loss = float(np.mean(train_eval_losses)) + else: + candidate_loss = train_metrics["loss"] - writer.close() + if candidate_loss < best_loss: + best_loss = candidate_loss + save_checkpoint( + checkpoint_path=run_dir / "checkpoint_best.pth", + model_path=run_dir / "model_best.pth", + model=model, + optimizer=optimizer, + scheduler=scheduler, + epoch=epoch, + train_cfg=train_cfg, + network_cfg=network_cfg, + dataset_cfg=dataset_cfg, + ) + print(f"[NavVLA] best model updated: loss={best_loss:.6f} epoch={epoch}") + + save_checkpoint( + checkpoint_path=run_dir / "checkpoint_latest.pth", + model_path=run_dir / "model_latest.pth", + model=model, + optimizer=optimizer, + scheduler=scheduler, + epoch=epoch, + train_cfg=train_cfg, + network_cfg=network_cfg, + dataset_cfg=dataset_cfg, + ) - run_dir.mkdir(parents=True, exist_ok=True) - torch.save(model.state_dict(), run_dir / "model_latest.pth") + writer.close() return 0 diff --git a/training/train.py b/training/train.py index 451f006..68df5a0 100644 --- a/training/train.py +++ b/training/train.py @@ -31,6 +31,7 @@ def run(self, max_steps: Optional[int] = None) -> Dict[str, float]: total_loss = 0.0 total_action_loss = 0.0 + total_dist_loss = 0.0 total_batches = 0 total_steps = len(self.loader) if max_steps is None else min(len(self.loader), max_steps) @@ -42,7 +43,7 @@ def run(self, max_steps: Optional[int] = None) -> Dict[str, float]: batch = {key: value.to(self.device) for key, value in raw_batch.items()} self.optimizer.zero_grad(set_to_none=True) - action_pred, _, _ = self.model( + action_pred, dist_pred, _ = self.model( batch["obs_images"], batch["goal_pose"].float(), batch["map_images"], @@ -52,14 +53,16 @@ def run(self, max_steps: Optional[int] = None) -> Dict[str, float]: batch["current_img"], ) action_loss = F.l1_loss(action_pred, batch["actions"].float()) - action_loss.backward() + dist_loss = F.l1_loss(dist_pred.squeeze(-1), batch["dist_to_goal"].float()) + loss = action_loss + dist_loss + loss.backward() self.optimizer.step() - loss_value = float(action_loss.detach().cpu()) - total_loss += loss_value - total_action_loss += loss_value + total_loss += float(loss.detach().cpu()) + total_action_loss += float(action_loss.detach().cpu()) + total_dist_loss += float(dist_loss.detach().cpu()) total_batches += 1 - progress.set_postfix(loss=f"{loss_value:.4f}") + progress.set_postfix(loss=f"{float(loss.detach().cpu()):.4f}") if total_batches == 0: raise RuntimeError("Train loader produced no batches.") @@ -67,4 +70,5 @@ def run(self, max_steps: Optional[int] = None) -> Dict[str, float]: return { "loss": total_loss / total_batches, "action_loss": total_action_loss / total_batches, + "dist_loss": total_dist_loss / total_batches, }