From d4040f2446088ae9ed559a946fb69bb73641654a Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 12 Mar 2026 11:28:14 -0700 Subject: [PATCH 01/42] better native module debugging --- dimos/core/native_module.py | 41 ++++++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 6a93e6453a..2ebe0a7f2d 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -146,7 +146,13 @@ def start(self) -> None: env = {**os.environ, **self.config.extra_env} cwd = self.config.cwd or str(Path(self.config.executable).resolve().parent) - logger.info("Starting native process", cmd=" ".join(cmd), cwd=cwd) + module_name = type(self).__name__ + logger.info( + f"Starting native process: {module_name}", + module=module_name, + cmd=" ".join(cmd), + cwd=cwd, + ) self._process = subprocess.Popen( cmd, env=env, @@ -154,7 +160,11 @@ def start(self) -> None: stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) - logger.info("Native process started", pid=self._process.pid) + logger.info( + f"Native process started: {module_name}", + module=module_name, + pid=self._process.pid, + ) self._stopping = False self._watchdog = threading.Thread(target=self._watch_process, daemon=True) @@ -193,10 +203,27 @@ def _watch_process(self) -> None: if self._stopping: return + + module_name = type(self).__name__ + exe_name = Path(self.config.executable).name if self.config.executable else "unknown" + + # Collect any remaining stderr for the crash report + last_stderr = "" + if self._process.stderr and not self._process.stderr.closed: + try: + remaining = self._process.stderr.read() + if remaining: + last_stderr = remaining.decode("utf-8", errors="replace").strip() + except Exception: + pass + logger.error( - "Native process died unexpectedly", + f"Native process crashed: {module_name} ({exe_name})", + module=module_name, + executable=exe_name, pid=self._process.pid, returncode=rc, + last_stderr=last_stderr[:500] if last_stderr else None, ) self.stop() @@ -265,12 +292,16 @@ def _maybe_build(self) -> None: if line.strip(): logger.warning(line) if proc.returncode != 0: + stderr_tail = stderr.decode("utf-8", errors="replace").strip()[-1000:] raise RuntimeError( - f"Build command failed (exit {proc.returncode}): {self.config.build_command}" + f"Build command failed (exit {proc.returncode}): {self.config.build_command}\n" + f"stderr: {stderr_tail}" ) if not exe.exists(): raise FileNotFoundError( - f"Build command succeeded but executable still not found: {exe}" + f"Build command succeeded but executable still not found: {exe}\n" + f"Build output may have been written to a different path. " + f"Check that build_command produces the executable at the expected location." ) def _collect_topics(self) -> dict[str, str]: From c581f9ce9bf1c17cf04950dc5d643f11fe1fda31 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 12 Mar 2026 13:29:23 -0700 Subject: [PATCH 02/42] add unity sim, part 1 --- dimos/robot/all_blueprints.py | 1 + dimos/simulation/unity/__init__.py | 0 dimos/simulation/unity/__main__.py | 20 + dimos/simulation/unity/blueprint.py | 69 +++ dimos/simulation/unity/module.py | 742 +++++++++++++++++++++++ dimos/simulation/unity/test_unity_sim.py | 316 ++++++++++ dimos/utils/ros1.py | 397 ++++++++++++ 7 files changed, 1545 insertions(+) create mode 100644 dimos/simulation/unity/__init__.py create mode 100644 dimos/simulation/unity/__main__.py create mode 100644 dimos/simulation/unity/blueprint.py create mode 100644 dimos/simulation/unity/module.py create mode 100644 dimos/simulation/unity/test_unity_sim.py create mode 100644 dimos/utils/ros1.py diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index e82cb656ce..45808971e7 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -86,6 +86,7 @@ "unitree-go2-spatial": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial:unitree_go2_spatial", "unitree-go2-temporal-memory": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_temporal_memory:unitree_go2_temporal_memory", "unitree-go2-vlm-stream-test": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_vlm_stream_test:unitree_go2_vlm_stream_test", + "unity-sim-blueprint": "dimos.simulation.unity.blueprint:unity_sim_blueprint", "xarm-perception": "dimos.manipulation.blueprints:xarm_perception", "xarm-perception-agent": "dimos.manipulation.blueprints:xarm_perception_agent", "xarm6-planner-only": "dimos.manipulation.blueprints:xarm6_planner_only", diff --git a/dimos/simulation/unity/__init__.py b/dimos/simulation/unity/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/simulation/unity/__main__.py b/dimos/simulation/unity/__main__.py new file mode 100644 index 0000000000..0e25da7ed7 --- /dev/null +++ b/dimos/simulation/unity/__main__.py @@ -0,0 +1,20 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the standalone Unity sim blueprint: python -m dimos.simulation.unity""" + +from dimos.simulation.unity.blueprint import main + +if __name__ == "__main__": + main() diff --git a/dimos/simulation/unity/blueprint.py b/dimos/simulation/unity/blueprint.py new file mode 100644 index 0000000000..5f5dd139fd --- /dev/null +++ b/dimos/simulation/unity/blueprint.py @@ -0,0 +1,69 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Standalone Unity sim blueprint — interactive test of the Unity bridge. + +Launches the Unity simulator, displays lidar + camera in Rerun, and accepts +keyboard teleop via TUI. No navigation stack — just raw sim data. + +Usage: + python -m dimos.simulation.unity.blueprint +""" + +from __future__ import annotations + +from typing import Any + +from dimos.core.blueprints import autoconnect +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.simulation.unity.module import UnityBridgeModule +from dimos.visualization.rerun.bridge import _resolve_viewer_mode, rerun_bridge + + +def _rerun_blueprint() -> Any: + import rerun.blueprint as rrb + + return rrb.Blueprint( + rrb.Vertical( + rrb.Spatial3DView(origin="world", name="3D"), + rrb.Spatial2DView(origin="world/color_image", name="Camera"), + row_shares=[2, 1], + ), + ) + + +rerun_config = { + "blueprint": _rerun_blueprint, + "pubsubs": [LCM()], + "visual_override": { + "world/camera_info": UnityBridgeModule.rerun_suppress_camera_info, + }, + "static": { + "world/color_image": UnityBridgeModule.rerun_static_pinhole, + }, +} + + +unity_sim_blueprint = autoconnect( + UnityBridgeModule.blueprint(), + rerun_bridge(viewer_mode=_resolve_viewer_mode(), **rerun_config), +) + + +def main() -> None: + unity_sim_blueprint.build().loop() + + +if __name__ == "__main__": + main() diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py new file mode 100644 index 0000000000..f1933cbde3 --- /dev/null +++ b/dimos/simulation/unity/module.py @@ -0,0 +1,742 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""UnityBridgeModule: TCP bridge to the VLA Challenge Unity simulator. + +Implements the ROS-TCP-Endpoint binary protocol to communicate with Unity +directly — no ROS dependency needed, no Unity-side changes. + +Unity sends simulated sensor data (lidar PointCloud2, compressed camera images). +We send back vehicle PoseStamped updates so Unity renders the robot position. + +Protocol (per message on the TCP stream): + [4 bytes LE uint32] destination string length + [N bytes] destination string (topic name or __syscommand) + [4 bytes LE uint32] message payload length + [M bytes] payload (ROS1-serialized message, or JSON for syscommands) +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import json +import math +import os +from pathlib import Path +import platform +from queue import Empty, Queue +import socket +import struct +import subprocess +import threading +import time +from typing import Any +import zipfile + +import numpy as np +from reactivex.disposable import Disposable + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.logging_config import setup_logger +from dimos.utils.ros1 import ( + deserialize_compressed_image, + deserialize_pointcloud2, + serialize_pose_stamped, +) + +logger = setup_logger() +PI = math.pi + +# Google Drive folder containing environment zips +_GDRIVE_FOLDER_ID = "1UD5v6cSfcwIMWmsq9WSk7blJut4kgb-1" +_DEFAULT_SCENE = "office_1" +_SUPPORTED_SYSTEMS = {"Linux"} +_SUPPORTED_ARCHS = {"x86_64", "AMD64"} + + +# --------------------------------------------------------------------------- +# TCP protocol helpers +# --------------------------------------------------------------------------- + + +def _recvall(sock: socket.socket, size: int) -> bytes: + buf = bytearray(size) + view = memoryview(buf) + pos = 0 + while pos < size: + n = sock.recv_into(view[pos:], size - pos) + if not n: + raise OSError("Connection closed") + pos += n + return bytes(buf) + + +def _read_tcp_message(sock: socket.socket) -> tuple[str, bytes]: + dest_len = struct.unpack(" 0 else b"" + return dest, msg_data + + +def _write_tcp_message(sock: socket.socket, destination: str, data: bytes) -> None: + dest_bytes = destination.encode("utf-8") + sock.sendall( + struct.pack(" None: + dest_bytes = command.encode("utf-8") + json_bytes = json.dumps(params).encode("utf-8") + sock.sendall( + struct.pack(" Path: + """Download a Unity environment zip from Google Drive and extract it. + + Returns the path to the Model.x86_64 binary. + """ + try: + import gdown # type: ignore[import-untyped] + except ImportError: + raise RuntimeError( + "Unity sim binary not found and 'gdown' is not installed for auto-download. " + "Install it with: pip install gdown\n" + "Or manually download from: " + f"https://drive.google.com/drive/folders/{_GDRIVE_FOLDER_ID}" + ) + + dest_dir.mkdir(parents=True, exist_ok=True) + zip_path = dest_dir / f"{scene}.zip" + + if not zip_path.exists(): + print("\n" + "=" * 70, flush=True) + print(f" DOWNLOADING UNITY SIMULATOR — scene: '{scene}'", flush=True) + print(" Source: Google Drive (VLA Challenge environments)", flush=True) + print(" Size: ~130-580 MB per scene (depends on scene complexity)", flush=True) + print(f" Destination: {dest_dir}", flush=True) + print(" This is a one-time download. Subsequent runs use the cache.", flush=True) + print("=" * 70 + "\n", flush=True) + gdown.download_folder( + id=_GDRIVE_FOLDER_ID, + output=str(dest_dir), + quiet=False, + ) + # gdown downloads all scenes into a subfolder; find our zip + for candidate in dest_dir.rglob(f"{scene}.zip"): + zip_path = candidate + break + + if not zip_path.exists(): + raise FileNotFoundError( + f"Failed to download scene '{scene}'. " + f"Check https://drive.google.com/drive/folders/{_GDRIVE_FOLDER_ID}" + ) + + # Extract + extract_dir = dest_dir / scene + if not extract_dir.exists(): + logger.info(f"Extracting {zip_path}...") + with zipfile.ZipFile(zip_path, "r") as zf: + zf.extractall(dest_dir) + + binary = extract_dir / "environment" / "Model.x86_64" + if not binary.exists(): + raise FileNotFoundError( + f"Extracted scene but Model.x86_64 not found at {binary}. " + f"Expected structure: {scene}/environment/Model.x86_64" + ) + + binary.chmod(binary.stat().st_mode | 0o111) + return binary + + +# --------------------------------------------------------------------------- +# Platform validation +# --------------------------------------------------------------------------- + + +def _validate_platform() -> None: + """Raise if the current platform can't run the Unity x86_64 binary.""" + system = platform.system() + arch = platform.machine() + + if system not in _SUPPORTED_SYSTEMS: + raise RuntimeError( + f"Unity simulator requires Linux x86_64 but running on {system} {arch}. " + f"macOS and Windows are not supported (the binary is a Linux ELF executable). " + f"Use a Linux VM, Docker, or WSL2." + ) + + if arch not in _SUPPORTED_ARCHS: + raise RuntimeError( + f"Unity simulator requires x86_64 but running on {arch}. " + f"ARM64 Linux is not supported. Use an x86_64 machine or emulation layer." + ) + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +@dataclass +class UnityBridgeConfig(ModuleConfig): + """Configuration for the Unity bridge / vehicle simulator. + + Set ``unity_binary=""`` to skip launching Unity and connect to an + already-running instance. Set ``auto_download=True`` (default) to + automatically download the scene if the binary is missing. + """ + + # Path to the Unity x86_64 binary. Relative paths resolved from cwd. + # Leave empty to auto-detect from cache or auto-download. + unity_binary: str = "" + + # Scene name for auto-download (e.g. "office_1", "hotel_room_1"). + # Only used when unity_binary is not found and auto_download is True. + unity_scene: str = _DEFAULT_SCENE + + # Directory to download/cache Unity scenes. + unity_cache_dir: str = "~/.cache/smartnav/unity_envs" + + # Auto-download the scene from Google Drive if binary is missing. + auto_download: bool = True + + # Max seconds to wait for Unity to connect after launch. + unity_connect_timeout: float = 30.0 + + # TCP server settings (we listen; Unity connects to us). + unity_host: str = "0.0.0.0" + unity_port: int = 10000 + + # Run Unity with no visible window (set -batchmode -nographics). + # Note: headless mode may not produce camera images. + headless: bool = False + + # Extra CLI args to pass to the Unity binary. + unity_extra_args: list[str] = field(default_factory=list) + + # Vehicle parameters + sensor_offset_x: float = 0.0 + sensor_offset_y: float = 0.0 + vehicle_height: float = 0.75 + + # Initial vehicle pose + init_x: float = 0.0 + init_y: float = 0.0 + init_z: float = 0.0 + init_yaw: float = 0.0 + + # Kinematic sim rate (Hz) for odometry integration + sim_rate: float = 200.0 + + +# --------------------------------------------------------------------------- +# Module +# --------------------------------------------------------------------------- + + +class UnityBridgeModule(Module[UnityBridgeConfig]): + """TCP bridge to the Unity simulator with kinematic odometry integration. + + Ports: + cmd_vel (In[Twist]): Velocity commands. + terrain_map (In[PointCloud2]): Terrain for Z adjustment. + odometry (Out[Odometry]): Vehicle state at sim_rate. + registered_scan (Out[PointCloud2]): Lidar from Unity. + color_image (Out[Image]): RGB camera from Unity (1920x640 panoramic). + semantic_image (Out[Image]): Semantic segmentation from Unity. + camera_info (Out[CameraInfo]): Camera intrinsics. + """ + + default_config = UnityBridgeConfig + + cmd_vel: In[Twist] + terrain_map: In[PointCloud2] + odometry: Out[Odometry] + registered_scan: Out[PointCloud2] + color_image: Out[Image] + semantic_image: Out[Image] + camera_info: Out[CameraInfo] + + # Rerun static config for 3D camera projection — use this when building + # your rerun_config so the panoramic image renders correctly in 3D. + # + # Usage: + # rerun_config = { + # "static": {"world/color_image": UnityBridgeModule.rerun_static_pinhole}, + # "visual_override": {"world/camera_info": UnityBridgeModule.rerun_suppress_camera_info}, + # } + @staticmethod + def rerun_static_pinhole(rr: Any) -> list[Any]: + """Static Pinhole + Transform3D for the Unity panoramic camera.""" + width, height = 1920, 640 + hfov_rad = math.radians(120.0) + fx = (width / 2.0) / math.tan(hfov_rad / 2.0) + fy = fx + cx, cy = width / 2.0, height / 2.0 + return [ + rr.Pinhole( + resolution=[width, height], + focal_length=[fx, fy], + principal_point=[cx, cy], + camera_xyz=rr.ViewCoordinates.RDF, + ), + rr.Transform3D( + parent_frame="tf#/sensor", + translation=[0.0, 0.0, 0.1], + rotation=rr.Quaternion(xyzw=[0.5, -0.5, 0.5, -0.5]), + ), + ] + + @staticmethod + def rerun_suppress_camera_info(_: Any) -> None: + """Suppress CameraInfo logging — the static pinhole handles 3D projection.""" + return None + + # ---- lifecycle -------------------------------------------------------- + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + self._x = self.config.init_x + self._y = self.config.init_y + self._z = self.config.init_z + self.config.vehicle_height + self._roll = 0.0 + self._pitch = 0.0 + self._yaw = self.config.init_yaw + self._terrain_z = self.config.init_z + self._fwd_speed = 0.0 + self._left_speed = 0.0 + self._yaw_rate = 0.0 + self._cmd_lock = threading.Lock() + self._state_lock = threading.Lock() + self._running = False + self._sim_thread: threading.Thread | None = None + self._unity_thread: threading.Thread | None = None + self._unity_connected = False + self._unity_ready = threading.Event() + self._unity_process: subprocess.Popen | None = None # type: ignore[type-arg] + self._send_queue: Queue[tuple[str, bytes]] = Queue() + + def __getstate__(self) -> dict[str, Any]: # type: ignore[override] + state: dict[str, Any] = super().__getstate__() # type: ignore[no-untyped-call] + for key in ( + "_cmd_lock", + "_state_lock", + "_sim_thread", + "_unity_thread", + "_unity_process", + "_send_queue", + "_unity_ready", + ): + state.pop(key, None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + super().__setstate__(state) + self._cmd_lock = threading.Lock() + self._state_lock = threading.Lock() + self._sim_thread = None + self._unity_thread = None + self._unity_process = None + self._send_queue = Queue() + self._unity_ready = threading.Event() + self._running = False + + @rpc + def start(self) -> None: + super().start() + self._disposables.add(Disposable(self.cmd_vel.subscribe(self._on_cmd_vel))) + self._disposables.add(Disposable(self.terrain_map.subscribe(self._on_terrain))) + self._running = True + self._sim_thread = threading.Thread(target=self._sim_loop, daemon=True) + self._sim_thread.start() + self._unity_thread = threading.Thread(target=self._unity_loop, daemon=True) + self._unity_thread.start() + self._launch_unity() + + @rpc + def stop(self) -> None: + self._running = False + if self._sim_thread: + self._sim_thread.join(timeout=2.0) + if self._unity_thread: + self._unity_thread.join(timeout=2.0) + if self._unity_process is not None and self._unity_process.poll() is None: + import signal as _sig + + logger.info(f"Stopping Unity (pid={self._unity_process.pid})") + self._unity_process.send_signal(_sig.SIGTERM) + try: + self._unity_process.wait(timeout=5) + except Exception: + self._unity_process.kill() + self._unity_process = None + super().stop() + + # ---- Unity process management ----------------------------------------- + + def _resolve_binary(self) -> Path | None: + """Find the Unity binary, downloading if needed. Returns None to skip launch.""" + cfg = self.config + + # Explicit path provided + if cfg.unity_binary: + p = Path(cfg.unity_binary) + if not p.is_absolute(): + p = Path.cwd() / p + if not p.exists(): + p = (Path(__file__).resolve().parent / cfg.unity_binary).resolve() + if p.exists(): + return p + if not cfg.auto_download: + logger.error( + f"Unity binary not found at {p} and auto_download is disabled. " + f"Set unity_binary to a valid path or enable auto_download." + ) + return None + + # Auto-download + if cfg.auto_download: + _validate_platform() + cache = Path(cfg.unity_cache_dir).expanduser() + candidate = cache / cfg.unity_scene / "environment" / "Model.x86_64" + if candidate.exists(): + return candidate + logger.info(f"Unity binary not found, downloading scene '{cfg.unity_scene}'...") + return _download_unity_scene(cfg.unity_scene, cache) + + return None + + def _launch_unity(self) -> None: + """Launch the Unity simulator binary as a subprocess.""" + binary_path = self._resolve_binary() + if binary_path is None: + logger.info("No Unity binary — TCP server will wait for external connection") + return + + _validate_platform() + + if not os.access(binary_path, os.X_OK): + binary_path.chmod(binary_path.stat().st_mode | 0o111) + + cmd = [str(binary_path)] + if self.config.headless: + cmd.extend(["-batchmode", "-nographics"]) + cmd.extend(self.config.unity_extra_args) + + logger.info(f"Launching Unity: {' '.join(cmd)}") + env = {**os.environ} + if "DISPLAY" not in env and not self.config.headless: + env["DISPLAY"] = ":0" + + self._unity_process = subprocess.Popen( + cmd, + cwd=str(binary_path.parent), + env=env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + logger.info(f"Unity pid={self._unity_process.pid}, waiting for TCP connection...") + + if self._unity_ready.wait(timeout=self.config.unity_connect_timeout): + logger.info("Unity connected") + else: + # Check if process died + rc = self._unity_process.poll() + if rc is not None: + logger.error( + f"Unity process exited with code {rc} before connecting. " + f"Check that DISPLAY is set and the binary is not corrupted." + ) + else: + logger.warning( + f"Unity did not connect within {self.config.unity_connect_timeout}s. " + f"The binary may still be loading — it will connect when ready." + ) + + # ---- input callbacks -------------------------------------------------- + + def _on_cmd_vel(self, twist: Twist) -> None: + with self._cmd_lock: + self._fwd_speed = twist.linear.x + self._left_speed = twist.linear.y + self._yaw_rate = twist.angular.z + + def _on_terrain(self, cloud: PointCloud2) -> None: + points, _ = cloud.as_numpy() + if len(points) == 0: + return + dx = points[:, 0] - self._x + dy = points[:, 1] - self._y + near = points[np.sqrt(dx * dx + dy * dy) < 0.5] + if len(near) >= 10: + with self._state_lock: + self._terrain_z = 0.8 * self._terrain_z + 0.2 * near[:, 2].mean() + + # ---- Unity TCP bridge ------------------------------------------------- + + def _unity_loop(self) -> None: + server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_sock.bind((self.config.unity_host, self.config.unity_port)) + server_sock.listen(1) + server_sock.settimeout(2.0) + logger.info(f"TCP server on :{self.config.unity_port}") + + while self._running: + try: + conn, addr = server_sock.accept() + logger.info(f"Unity connected from {addr}") + try: + self._bridge_connection(conn) + except Exception as e: + logger.info(f"Unity connection ended: {e}") + finally: + with self._state_lock: + self._unity_connected = False + conn.close() + except TimeoutError: + continue + except Exception as e: + if self._running: + logger.warning(f"TCP server error: {e}") + time.sleep(1.0) + + server_sock.close() + + def _bridge_connection(self, sock: socket.socket) -> None: + sock.settimeout(None) + with self._state_lock: + self._unity_connected = True + self._unity_ready.set() + + _write_tcp_command( + sock, + "__handshake", + { + "version": "v0.7.0", + "metadata": json.dumps({"protocol": "ROS2"}), + }, + ) + + halt = threading.Event() + sender = threading.Thread(target=self._unity_sender, args=(sock, halt), daemon=True) + sender.start() + + try: + while self._running and not halt.is_set(): + dest, data = _read_tcp_message(sock) + if dest == "": + continue + elif dest.startswith("__"): + self._handle_syscommand(dest, data) + else: + self._handle_unity_message(dest, data) + finally: + halt.set() + sender.join(timeout=2.0) + with self._state_lock: + self._unity_connected = False + + def _unity_sender(self, sock: socket.socket, halt: threading.Event) -> None: + while not halt.is_set(): + try: + dest, data = self._send_queue.get(timeout=1.0) + if dest == "__raw__": + sock.sendall(data) + else: + _write_tcp_message(sock, dest, data) + except Empty: + continue + except Exception: + halt.set() + + def _handle_syscommand(self, dest: str, data: bytes) -> None: + payload = data.rstrip(b"\x00") + try: + params = json.loads(payload.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError): + params = {} + + cmd = dest[2:] + logger.info(f"Unity syscmd: {cmd} {params}") + + if cmd == "topic_list": + resp = json.dumps( + { + "topics": ["/unity_sim/set_model_state", "/tf"], + "types": ["geometry_msgs/PoseStamped", "tf2_msgs/TFMessage"], + } + ).encode("utf-8") + hdr = b"__topic_list" + frame = struct.pack(" None: + if topic == "/registered_scan": + pc_result = deserialize_pointcloud2(data) + if pc_result is not None: + points, frame_id, ts = pc_result + if len(points) > 0: + self.registered_scan.publish( + PointCloud2.from_numpy(points, frame_id=frame_id, timestamp=ts) + ) + + elif "image" in topic and "compressed" in topic: + img_result = deserialize_compressed_image(data) + if img_result is not None: + img_bytes, _fmt, _frame_id, ts = img_result + try: + import cv2 + + decoded = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR) + if decoded is not None: + img = Image.from_numpy(decoded, frame_id="camera", ts=ts) + if "semantic" in topic: + self.semantic_image.publish(img) + else: + self.color_image.publish(img) + h, w = decoded.shape[:2] + self._publish_camera_info(w, h, ts) + except Exception as e: + logger.warning(f"Image decode failed ({topic}): {e}") + + def _publish_camera_info(self, width: int, height: int, ts: float) -> None: + fx = fy = height / 2.0 + cx, cy = width / 2.0, height / 2.0 + self.camera_info.publish( + CameraInfo( + height=height, + width=width, + distortion_model="plumb_bob", + D=[0.0, 0.0, 0.0, 0.0, 0.0], + K=[fx, 0.0, cx, 0.0, fy, cy, 0.0, 0.0, 1.0], + R=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + P=[fx, 0.0, cx, 0.0, 0.0, fy, cy, 0.0, 0.0, 0.0, 1.0, 0.0], + frame_id="camera", + ts=ts, + ) + ) + + def _send_to_unity(self, topic: str, data: bytes) -> None: + with self._state_lock: + connected = self._unity_connected + if connected: + self._send_queue.put((topic, data)) + + # ---- kinematic sim loop ----------------------------------------------- + + def _sim_loop(self) -> None: + dt = 1.0 / self.config.sim_rate + + while self._running: + t0 = time.monotonic() + + with self._cmd_lock: + fwd, left, yaw_rate = self._fwd_speed, self._left_speed, self._yaw_rate + + prev_z = self._z + + self._yaw += dt * yaw_rate + if self._yaw > PI: + self._yaw -= 2 * PI + elif self._yaw < -PI: + self._yaw += 2 * PI + + cy, sy = math.cos(self._yaw), math.sin(self._yaw) + self._x += dt * cy * fwd - dt * sy * left + self._y += dt * sy * fwd + dt * cy * left + with self._state_lock: + terrain_z = self._terrain_z + self._z = terrain_z + self.config.vehicle_height + + now = time.time() + quat = Quaternion.from_euler(Vector3(self._roll, self._pitch, self._yaw)) + + self.odometry.publish( + Odometry( + ts=now, + frame_id="map", + child_frame_id="sensor", + pose=Pose( + position=[self._x, self._y, self._z], + orientation=[quat.x, quat.y, quat.z, quat.w], + ), + twist=Twist( + linear=[fwd, left, (self._z - prev_z) * self.config.sim_rate], + angular=[0.0, 0.0, yaw_rate], + ), + ) + ) + + self.tf.publish( + Transform( + translation=Vector3(self._x, self._y, self._z), + rotation=quat, + frame_id="map", + child_frame_id="sensor", + ts=now, + ), + Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="map", + child_frame_id="world", + ts=now, + ), + ) + + with self._state_lock: + unity_connected = self._unity_connected + if unity_connected: + self._send_to_unity( + "/unity_sim/set_model_state", + serialize_pose_stamped( + self._x, + self._y, + self._z, + quat.x, + quat.y, + quat.z, + quat.w, + ), + ) + + sleep_for = dt - (time.monotonic() - t0) + if sleep_for > 0: + time.sleep(sleep_for) diff --git a/dimos/simulation/unity/test_unity_sim.py b/dimos/simulation/unity/test_unity_sim.py new file mode 100644 index 0000000000..fb8bf0b7a5 --- /dev/null +++ b/dimos/simulation/unity/test_unity_sim.py @@ -0,0 +1,316 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the Unity simulator bridge module. + +Markers: + - No special markers needed for unit tests (all run on any platform). + - Tests that launch the actual Unity binary should use: + @pytest.mark.slow + @pytest.mark.skipif(platform.system() != "Linux" or platform.machine() not in ("x86_64", "AMD64"), + reason="Unity binary requires Linux x86_64") + @pytest.mark.skipif(not os.environ.get("DISPLAY"), reason="Unity requires a display (X11)") +""" + +import os +import pickle +import platform +import socket +import struct +import threading +import time + +import numpy as np +import pytest + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.simulation.unity.module import ( + UnityBridgeConfig, + UnityBridgeModule, + _validate_platform, +) +from dimos.utils.ros1 import ROS1Writer, deserialize_pointcloud2 + +_is_linux_x86 = platform.system() == "Linux" and platform.machine() in ("x86_64", "AMD64") +_has_display = bool(os.environ.get("DISPLAY")) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _MockTransport: + def __init__(self): + self._messages = [] + self._subscribers = [] + + def publish(self, msg): + self._messages.append(msg) + for cb in self._subscribers: + cb(msg) + + def broadcast(self, _s, msg): + self.publish(msg) + + def subscribe(self, cb, *_a): + self._subscribers.append(cb) + return lambda: self._subscribers.remove(cb) + + +def _wire(module) -> dict[str, _MockTransport]: + ts = {} + for name in ( + "odometry", + "registered_scan", + "cmd_vel", + "terrain_map", + "color_image", + "semantic_image", + "camera_info", + ): + t = _MockTransport() + getattr(module, name)._transport = t + ts[name] = t + return ts + + +def _find_free_port() -> int: + with socket.socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _build_ros1_pointcloud2(points: np.ndarray, frame_id: str = "map") -> bytes: + w = ROS1Writer() + w.u32(0) + w.time() + w.string(frame_id) + n = len(points) + w.u32(1) + w.u32(n) + w.u32(4) + for i, name in enumerate(["x", "y", "z", "intensity"]): + w.string(name) + w.u32(i * 4) + w.u8(7) + w.u32(1) + w.u8(0) + w.u32(16) + w.u32(16 * n) + data = np.column_stack([points, np.zeros(n, dtype=np.float32)]).astype(np.float32).tobytes() + w.u32(len(data)) + w.raw(data) + w.u8(1) + return w.bytes() + + +def _send_tcp(sock, dest: str, data: bytes): + d = dest.encode() + sock.sendall(struct.pack(" tuple[str, bytes]: + dl = struct.unpack("= 1 + received_pts, _ = ts["registered_scan"]._messages[0].as_numpy() + np.testing.assert_allclose(received_pts, pts, atol=0.01) + + +# --------------------------------------------------------------------------- +# Kinematic Sim — needs threading, ~1s, runs everywhere +# --------------------------------------------------------------------------- + + +class TestKinematicSim: + def test_odometry_published(self): + m = UnityBridgeModule(unity_binary="", sim_rate=100.0) + ts = _wire(m) + + m._running = True + m._sim_thread = threading.Thread(target=m._sim_loop, daemon=True) + m._sim_thread.start() + time.sleep(0.2) + m._running = False + m._sim_thread.join(timeout=2) + m.stop() + + assert len(ts["odometry"]._messages) > 5 + assert ts["odometry"]._messages[0].frame_id == "map" + + def test_cmd_vel_moves_robot(self): + m = UnityBridgeModule(unity_binary="", sim_rate=200.0) + ts = _wire(m) + + m._on_cmd_vel(Twist(linear=[1.0, 0.0, 0.0], angular=[0.0, 0.0, 0.0])) + m._running = True + m._sim_thread = threading.Thread(target=m._sim_loop, daemon=True) + m._sim_thread.start() + time.sleep(1.0) + m._running = False + m._sim_thread.join(timeout=2) + m.stop() + + last_odom = ts["odometry"]._messages[-1] + assert last_odom.x > 0.5 + + +# --------------------------------------------------------------------------- +# Rerun Config — fast, runs everywhere +# --------------------------------------------------------------------------- + + +class TestRerunConfig: + def test_static_pinhole_returns_list(self): + import rerun as rr + + result = UnityBridgeModule.rerun_static_pinhole(rr) + assert isinstance(result, list) + assert len(result) == 2 + + def test_suppress_returns_none(self): + assert UnityBridgeModule.rerun_suppress_camera_info(None) is None + + +# --------------------------------------------------------------------------- +# Live Unity — slow, requires Linux x86_64 + DISPLAY +# These are skipped in CI and on unsupported platforms. +# --------------------------------------------------------------------------- + + +@pytest.mark.slow +@pytest.mark.skipif(not _is_linux_x86, reason="Unity binary requires Linux x86_64") +@pytest.mark.skipif(not _has_display, reason="Unity requires DISPLAY (X11)") +class TestLiveUnity: + """Tests that launch the actual Unity binary. Skipped unless on Linux x86_64 with a display.""" + + def test_unity_connects_and_streams(self): + """Launch Unity, verify it connects and sends lidar + images.""" + m = UnityBridgeModule() # uses auto-download + ts = _wire(m) + + m.start() + time.sleep(25) + + assert m._unity_connected, "Unity did not connect" + assert len(ts["registered_scan"]._messages) > 5, "No lidar from Unity" + assert len(ts["color_image"]._messages) > 5, "No camera images from Unity" + assert len(ts["odometry"]._messages) > 100, "No odometry" + + m.stop() diff --git a/dimos/utils/ros1.py b/dimos/utils/ros1.py new file mode 100644 index 0000000000..e2d993d851 --- /dev/null +++ b/dimos/utils/ros1.py @@ -0,0 +1,397 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ROS1 binary message deserialization — no ROS1 installation required. + +Implements pure-Python deserialization of standard ROS1 message types from their +binary wire format (as used by the Unity ROS-TCP-Connector). These messages use +little-endian encoding with uint32-length-prefixed strings and arrays. + +Wire format basics: + - Primitive types: packed directly (e.g. uint32 = 4 bytes LE) + - Strings: uint32 length + N bytes (no null terminator in wire format) + - Arrays: uint32 count + N * element_size bytes + - Time: uint32 sec + uint32 nsec + - Nested messages: serialized inline (no length prefix for fixed-size) + +Supported types: + - sensor_msgs/PointCloud2 + - sensor_msgs/CompressedImage + - geometry_msgs/PoseStamped (serialize + deserialize) + - geometry_msgs/TwistStamped (serialize) + - nav_msgs/Odometry (deserialize) +""" + +from __future__ import annotations + +from dataclasses import dataclass +import struct +import time +from typing import Any + +import numpy as np + +# --------------------------------------------------------------------------- +# Low-level readers +# --------------------------------------------------------------------------- + + +class ROS1Reader: + """Stateful reader for ROS1 binary serialized data.""" + + __slots__ = ("data", "off") + + def __init__(self, data: bytes) -> None: + self.data = data + self.off = 0 + + def u8(self) -> int: + v = self.data[self.off] + self.off += 1 + return v + + def bool(self) -> bool: + return self.u8() != 0 + + def u32(self) -> int: + (v,) = struct.unpack_from(" int: + (v,) = struct.unpack_from(" float: + (v,) = struct.unpack_from(" float: + (v,) = struct.unpack_from(" str: + length = self.u32() + s = self.data[self.off : self.off + length].decode("utf-8", errors="replace") + self.off += length + return s + + def time(self) -> float: + """Read ROS1 time (uint32 sec + uint32 nsec) → float seconds.""" + sec = self.u32() + nsec = self.u32() + return sec + nsec / 1e9 + + def raw(self, n: int) -> bytes: + v = self.data[self.off : self.off + n] + self.off += n + return v + + def remaining(self) -> int: + return len(self.data) - self.off + + +# --------------------------------------------------------------------------- +# Low-level writer +# --------------------------------------------------------------------------- + + +class ROS1Writer: + """Stateful writer for ROS1 binary serialized data.""" + + def __init__(self) -> None: + self.buf = bytearray() + + def u8(self, v: int) -> None: + self.buf.append(v & 0xFF) + + def bool(self, v: bool) -> None: + self.u8(1 if v else 0) + + def u32(self, v: int) -> None: + self.buf += struct.pack(" None: + self.buf += struct.pack(" None: + self.buf += struct.pack(" None: + self.buf += struct.pack(" None: + b = s.encode("utf-8") + self.u32(len(b)) + self.buf += b + + def time(self, t: float | None = None) -> None: + if t is None: + t = time.time() + sec = int(t) + nsec = int((t - sec) * 1e9) + self.u32(sec) + self.u32(nsec) + + def raw(self, data: bytes) -> None: + self.buf += data + + def bytes(self) -> bytes: + return bytes(self.buf) + + +# --------------------------------------------------------------------------- +# Header (std_msgs/Header) +# --------------------------------------------------------------------------- + + +@dataclass +class ROS1Header: + seq: int = 0 + stamp: float = 0.0 # seconds + frame_id: str = "" + + +def read_header(r: ROS1Reader) -> ROS1Header: + seq = r.u32() + stamp = r.time() + frame_id = r.string() + return ROS1Header(seq, stamp, frame_id) + + +def write_header( + w: ROS1Writer, frame_id: str = "map", stamp: float | None = None, seq: int = 0 +) -> None: + w.u32(seq) + w.time(stamp) + w.string(frame_id) + + +# --------------------------------------------------------------------------- +# sensor_msgs/PointCloud2 +# --------------------------------------------------------------------------- + + +@dataclass +class ROS1PointField: + name: str + offset: int + datatype: int # 7=FLOAT32, 8=FLOAT64, etc. + count: int + + +def deserialize_pointcloud2(data: bytes) -> tuple[np.ndarray, str, float] | None: + """Deserialize ROS1 sensor_msgs/PointCloud2 → (Nx3 float32 points, frame_id, timestamp). + + Returns None on parse failure. + """ + try: + r = ROS1Reader(data) + header = read_header(r) + + height = r.u32() + width = r.u32() + num_points = height * width + + # PointField array + num_fields = r.u32() + x_off = y_off = z_off = -1 + for _ in range(num_fields): + name = r.string() + offset = r.u32() + r.u8() + r.u32() + if name == "x": + x_off = offset + elif name == "y": + y_off = offset + elif name == "z": + z_off = offset + + r.bool() + point_step = r.u32() + r.u32() + + # Data array + data_len = r.u32() + raw_data = r.raw(data_len) + + # is_dense + if r.remaining() > 0: + r.bool() + + if x_off < 0 or y_off < 0 or z_off < 0: + return None + if num_points == 0: + return np.zeros((0, 3), dtype=np.float32), header.frame_id, header.stamp + + # Fast path: standard XYZI layout + if x_off == 0 and y_off == 4 and z_off == 8 and point_step >= 12: + if point_step == 12: + points = ( + np.frombuffer(raw_data, dtype=np.float32, count=num_points * 3) + .reshape(-1, 3) + .copy() + ) + else: + dt = np.dtype( + [("x", " tuple[bytes, str, str, float] | None: + """Deserialize ROS1 sensor_msgs/CompressedImage → (raw_data, format, frame_id, timestamp). + + The raw_data is JPEG/PNG bytes that can be decoded with cv2.imdecode or PIL. + Returns None on parse failure. + """ + try: + r = ROS1Reader(data) + header = read_header(r) + fmt = r.string() # e.g. "jpeg", "png" + img_len = r.u32() + img_data = r.raw(img_len) + return img_data, fmt, header.frame_id, header.stamp + except Exception: + return None + + +# --------------------------------------------------------------------------- +# geometry_msgs/PoseStamped (serialize) +# --------------------------------------------------------------------------- + + +def serialize_pose_stamped( + x: float, + y: float, + z: float, + qx: float, + qy: float, + qz: float, + qw: float, + frame_id: str = "map", + stamp: float | None = None, +) -> bytes: + """Serialize geometry_msgs/PoseStamped in ROS1 wire format.""" + w = ROS1Writer() + write_header(w, frame_id, stamp) + # Pose: position (3x f64) + orientation (4x f64) + w.f64(x) + w.f64(y) + w.f64(z) + w.f64(qx) + w.f64(qy) + w.f64(qz) + w.f64(qw) + return w.bytes() + + +# --------------------------------------------------------------------------- +# geometry_msgs/TwistStamped (serialize) +# --------------------------------------------------------------------------- + + +def serialize_twist_stamped( + linear_x: float, + linear_y: float, + linear_z: float, + angular_x: float, + angular_y: float, + angular_z: float, + frame_id: str = "base_link", + stamp: float | None = None, +) -> bytes: + """Serialize geometry_msgs/TwistStamped in ROS1 wire format.""" + w = ROS1Writer() + write_header(w, frame_id, stamp) + # Twist: linear (3x f64) + angular (3x f64) + w.f64(linear_x) + w.f64(linear_y) + w.f64(linear_z) + w.f64(angular_x) + w.f64(angular_y) + w.f64(angular_z) + return w.bytes() + + +# --------------------------------------------------------------------------- +# nav_msgs/Odometry (deserialize) +# --------------------------------------------------------------------------- + + +def deserialize_odometry(data: bytes) -> tuple[dict[str, Any], str, str, float] | None: + """Deserialize ROS1 nav_msgs/Odometry. + + Returns (pose_dict, frame_id, child_frame_id, timestamp) or None. + pose_dict has keys: x, y, z, qx, qy, qz, qw, vx, vy, vz, wx, wy, wz + """ + try: + r = ROS1Reader(data) + header = read_header(r) + child_frame_id = r.string() + + # PoseWithCovariance: Pose (Point + Quaternion) + float64[36] + x, y, z = r.f64(), r.f64(), r.f64() + qx, qy, qz, qw = r.f64(), r.f64(), r.f64(), r.f64() + r.raw(36 * 8) # skip covariance + + # TwistWithCovariance: Twist (Vector3 + Vector3) + float64[36] + vx, vy, vz = r.f64(), r.f64(), r.f64() + wx, wy, wz = r.f64(), r.f64(), r.f64() + r.raw(36 * 8) # skip covariance + + return ( + { + "x": x, + "y": y, + "z": z, + "qx": qx, + "qy": qy, + "qz": qz, + "qw": qw, + "vx": vx, + "vy": vy, + "vz": vz, + "wx": wx, + "wy": wy, + "wz": wz, + }, + header.frame_id, + child_frame_id, + header.stamp, + ) + except Exception: + return None From e135fa2a698727cca64f192829967d73b29b132a Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 12 Mar 2026 13:47:47 -0700 Subject: [PATCH 03/42] clean up --- dimos/simulation/unity/module.py | 7 ++- dimos/utils/ros1.py | 79 -------------------------------- 2 files changed, 5 insertions(+), 81 deletions(-) diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index f1933cbde3..9e767f4421 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -251,8 +251,6 @@ class UnityBridgeConfig(ModuleConfig): unity_extra_args: list[str] = field(default_factory=list) # Vehicle parameters - sensor_offset_x: float = 0.0 - sensor_offset_y: float = 0.0 vehicle_height: float = 0.75 # Initial vehicle pose @@ -637,6 +635,11 @@ def _handle_unity_message(self, topic: str, data: bytes) -> None: logger.warning(f"Image decode failed ({topic}): {e}") def _publish_camera_info(self, width: int, height: int, ts: float) -> None: + # NOTE: The Unity camera is a 360-degree cylindrical panorama (1920x640). + # CameraInfo assumes a pinhole model, so this is an approximation. + # The Rerun static pinhole (rerun_static_pinhole) uses a different focal + # length tuned for a 120-deg FOV window because Rerun has no cylindrical + # projection support. These intentionally differ. fx = fy = height / 2.0 cx, cy = width / 2.0, height / 2.0 self.camera_info.publish( diff --git a/dimos/utils/ros1.py b/dimos/utils/ros1.py index e2d993d851..b3c6c43456 100644 --- a/dimos/utils/ros1.py +++ b/dimos/utils/ros1.py @@ -38,7 +38,6 @@ from dataclasses import dataclass import struct import time -from typing import Any import numpy as np @@ -317,81 +316,3 @@ def serialize_pose_stamped( w.f64(qz) w.f64(qw) return w.bytes() - - -# --------------------------------------------------------------------------- -# geometry_msgs/TwistStamped (serialize) -# --------------------------------------------------------------------------- - - -def serialize_twist_stamped( - linear_x: float, - linear_y: float, - linear_z: float, - angular_x: float, - angular_y: float, - angular_z: float, - frame_id: str = "base_link", - stamp: float | None = None, -) -> bytes: - """Serialize geometry_msgs/TwistStamped in ROS1 wire format.""" - w = ROS1Writer() - write_header(w, frame_id, stamp) - # Twist: linear (3x f64) + angular (3x f64) - w.f64(linear_x) - w.f64(linear_y) - w.f64(linear_z) - w.f64(angular_x) - w.f64(angular_y) - w.f64(angular_z) - return w.bytes() - - -# --------------------------------------------------------------------------- -# nav_msgs/Odometry (deserialize) -# --------------------------------------------------------------------------- - - -def deserialize_odometry(data: bytes) -> tuple[dict[str, Any], str, str, float] | None: - """Deserialize ROS1 nav_msgs/Odometry. - - Returns (pose_dict, frame_id, child_frame_id, timestamp) or None. - pose_dict has keys: x, y, z, qx, qy, qz, qw, vx, vy, vz, wx, wy, wz - """ - try: - r = ROS1Reader(data) - header = read_header(r) - child_frame_id = r.string() - - # PoseWithCovariance: Pose (Point + Quaternion) + float64[36] - x, y, z = r.f64(), r.f64(), r.f64() - qx, qy, qz, qw = r.f64(), r.f64(), r.f64(), r.f64() - r.raw(36 * 8) # skip covariance - - # TwistWithCovariance: Twist (Vector3 + Vector3) + float64[36] - vx, vy, vz = r.f64(), r.f64(), r.f64() - wx, wy, wz = r.f64(), r.f64(), r.f64() - r.raw(36 * 8) # skip covariance - - return ( - { - "x": x, - "y": y, - "z": z, - "qx": qx, - "qy": qy, - "qz": qz, - "qw": qw, - "vx": vx, - "vy": vy, - "vz": vz, - "wx": wx, - "wy": wy, - "wz": wz, - }, - header.frame_id, - child_frame_id, - header.stamp, - ) - except Exception: - return None From c3cf3e64c18e5092ea13b9cc62658d8710664be4 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 12 Mar 2026 13:56:14 -0700 Subject: [PATCH 04/42] cleaning --- dimos/robot/all_blueprints.py | 2 +- dimos/simulation/unity/__main__.py | 20 -------------------- dimos/simulation/unity/blueprint.py | 12 ++---------- 3 files changed, 3 insertions(+), 31 deletions(-) delete mode 100644 dimos/simulation/unity/__main__.py diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 45808971e7..f576fcbc2b 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -86,7 +86,7 @@ "unitree-go2-spatial": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial:unitree_go2_spatial", "unitree-go2-temporal-memory": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_temporal_memory:unitree_go2_temporal_memory", "unitree-go2-vlm-stream-test": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_vlm_stream_test:unitree_go2_vlm_stream_test", - "unity-sim-blueprint": "dimos.simulation.unity.blueprint:unity_sim_blueprint", + "unity-sim": "dimos.simulation.unity.blueprint:unity_sim", "xarm-perception": "dimos.manipulation.blueprints:xarm_perception", "xarm-perception-agent": "dimos.manipulation.blueprints:xarm_perception_agent", "xarm6-planner-only": "dimos.manipulation.blueprints:xarm6_planner_only", diff --git a/dimos/simulation/unity/__main__.py b/dimos/simulation/unity/__main__.py deleted file mode 100644 index 0e25da7ed7..0000000000 --- a/dimos/simulation/unity/__main__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Run the standalone Unity sim blueprint: python -m dimos.simulation.unity""" - -from dimos.simulation.unity.blueprint import main - -if __name__ == "__main__": - main() diff --git a/dimos/simulation/unity/blueprint.py b/dimos/simulation/unity/blueprint.py index 5f5dd139fd..cceb3e697e 100644 --- a/dimos/simulation/unity/blueprint.py +++ b/dimos/simulation/unity/blueprint.py @@ -18,7 +18,7 @@ keyboard teleop via TUI. No navigation stack — just raw sim data. Usage: - python -m dimos.simulation.unity.blueprint + dimos run unity-sim """ from __future__ import annotations @@ -55,15 +55,7 @@ def _rerun_blueprint() -> Any: } -unity_sim_blueprint = autoconnect( +unity_sim = autoconnect( UnityBridgeModule.blueprint(), rerun_bridge(viewer_mode=_resolve_viewer_mode(), **rerun_config), ) - - -def main() -> None: - unity_sim_blueprint.build().loop() - - -if __name__ == "__main__": - main() From d511bf97e147a19884d868a762b19284962be3a4 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 12 Mar 2026 14:33:48 -0700 Subject: [PATCH 05/42] improve binary downloading (google drive) --- .gitignore | 3 + dimos/simulation/unity/blueprint.py | 4 +- dimos/simulation/unity/module.py | 121 +++++++++++++++++------ dimos/simulation/unity/test_unity_sim.py | 8 +- 4 files changed, 100 insertions(+), 36 deletions(-) diff --git a/.gitignore b/.gitignore index 4045db012e..21bac35ead 100644 --- a/.gitignore +++ b/.gitignore @@ -73,6 +73,9 @@ CLAUDE.MD /.mcp.json *.speedscope.json +# Unity simulator cache +.unity_envs/ + # Coverage htmlcov/ .coverage diff --git a/dimos/simulation/unity/blueprint.py b/dimos/simulation/unity/blueprint.py index cceb3e697e..36d2005fcb 100644 --- a/dimos/simulation/unity/blueprint.py +++ b/dimos/simulation/unity/blueprint.py @@ -27,7 +27,7 @@ from dimos.core.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from dimos.simulation.unity.module import UnityBridgeModule +from dimos.simulation.unity.module import UnityBridgeModule, resolve_unity_binary from dimos.visualization.rerun.bridge import _resolve_viewer_mode, rerun_bridge @@ -58,4 +58,4 @@ def _rerun_blueprint() -> Any: unity_sim = autoconnect( UnityBridgeModule.blueprint(), rerun_bridge(viewer_mode=_resolve_viewer_mode(), **rerun_config), -) +).requirements(resolve_unity_binary()) diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index 9e767f4421..3fec1ad7b1 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -41,7 +41,7 @@ import subprocess import threading import time -from typing import Any +from typing import TYPE_CHECKING, Any import zipfile import numpy as np @@ -66,12 +66,15 @@ serialize_pose_stamped, ) +if TYPE_CHECKING: + from collections.abc import Callable + logger = setup_logger() PI = math.pi # Google Drive folder containing environment zips _GDRIVE_FOLDER_ID = "1UD5v6cSfcwIMWmsq9WSk7blJut4kgb-1" -_DEFAULT_SCENE = "office_1" +_DEFAULT_SCENE = "japanese_room_1" _SUPPORTED_SYSTEMS = {"Linux"} _SUPPORTED_ARCHS = {"x86_64", "AMD64"} @@ -143,18 +146,35 @@ def _download_unity_scene(scene: str, dest_dir: Path) -> Path: zip_path = dest_dir / f"{scene}.zip" if not zip_path.exists(): - print("\n" + "=" * 70, flush=True) - print(f" DOWNLOADING UNITY SIMULATOR — scene: '{scene}'", flush=True) - print(" Source: Google Drive (VLA Challenge environments)", flush=True) - print(" Size: ~130-580 MB per scene (depends on scene complexity)", flush=True) - print(f" Destination: {dest_dir}", flush=True) - print(" This is a one-time download. Subsequent runs use the cache.", flush=True) - print("=" * 70 + "\n", flush=True) + print(flush=True) + print("=" * 70, flush=True) + print("", flush=True) + print(" UNITY SIMULATOR DOWNLOAD", flush=True) + print("", flush=True) + print(f" The Unity simulator scene '{scene}' was not found locally.", flush=True) + print(" Downloading it now from Google Drive. This is a ONE-TIME", flush=True) + print(" download — future runs will use the cached binary.", flush=True) + print("", flush=True) + print(" Source: VLA Challenge (Google Drive)", flush=True) + print(f" Scene: {scene}", flush=True) + print(" Size: ~130-580 MB (depends on scene)", flush=True) + print(f" Cache: {dest_dir}", flush=True) + print("", flush=True) + print(" gdown will print progress below. This may take a few", flush=True) + print(" minutes depending on your connection speed.", flush=True) + print("", flush=True) + print("=" * 70, flush=True) + print(flush=True) gdown.download_folder( id=_GDRIVE_FOLDER_ID, output=str(dest_dir), quiet=False, ) + print(flush=True) + print("=" * 70, flush=True) + print(" Download complete. Locating scene zip...", flush=True) + print("=" * 70, flush=True) + print(flush=True) # gdown downloads all scenes into a subfolder; find our zip for candidate in dest_dir.rglob(f"{scene}.zip"): zip_path = candidate @@ -169,9 +189,10 @@ def _download_unity_scene(scene: str, dest_dir: Path) -> Path: # Extract extract_dir = dest_dir / scene if not extract_dir.exists(): - logger.info(f"Extracting {zip_path}...") + print(f" Extracting {zip_path.name}...", flush=True) with zipfile.ZipFile(zip_path, "r") as zf: zf.extractall(dest_dir) + print(" Extraction complete.", flush=True) binary = extract_dir / "environment" / "Model.x86_64" if not binary.exists(): @@ -208,6 +229,47 @@ def _validate_platform() -> None: ) +# --------------------------------------------------------------------------- +# Host-side binary resolution (runs BEFORE worker deploy) +# --------------------------------------------------------------------------- + + +def resolve_unity_binary( + scene: str = _DEFAULT_SCENE, + cache_dir: str = ".unity_envs", + auto_download: bool = True, +) -> Callable[[], str | None]: + """Return a blueprint requirement check that resolves the Unity binary. + + This runs on the HOST process during blueprint.build(), before modules + are deployed to worker subprocesses. If the binary is not cached and + auto_download is True, it downloads the scene from Google Drive. + + Usage in a blueprint:: + + unity_sim = autoconnect( + UnityBridgeModule.blueprint(), + ... + ).requirements(resolve_unity_binary()) + """ + + def _check() -> str | None: + cache = Path(cache_dir).expanduser() + candidate = cache / scene / "environment" / "Model.x86_64" + if candidate.exists(): + return None # already cached, no error + + if not auto_download: + return f"Unity scene '{scene}' not found at {candidate} and auto_download is disabled." + + _validate_platform() + logger.info(f"Unity binary not found, downloading scene '{scene}'...") + _download_unity_scene(scene, cache) + return None # success + + return _check + + # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- @@ -230,8 +292,8 @@ class UnityBridgeConfig(ModuleConfig): # Only used when unity_binary is not found and auto_download is True. unity_scene: str = _DEFAULT_SCENE - # Directory to download/cache Unity scenes. - unity_cache_dir: str = "~/.cache/smartnav/unity_envs" + # Directory to download/cache Unity scenes (relative to cwd). + unity_cache_dir: str = ".unity_envs" # Auto-download the scene from Google Drive if binary is missing. auto_download: bool = True @@ -349,6 +411,7 @@ def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] self._unity_ready = threading.Event() self._unity_process: subprocess.Popen | None = None # type: ignore[type-arg] self._send_queue: Queue[tuple[str, bytes]] = Queue() + self._binary_path = self._resolve_binary() def __getstate__(self) -> dict[str, Any]: # type: ignore[override] state: dict[str, Any] = super().__getstate__() # type: ignore[no-untyped-call] @@ -374,6 +437,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: self._send_queue = Queue() self._unity_ready = threading.Event() self._running = False + self._binary_path = self._resolve_binary() @rpc def start(self) -> None: @@ -409,7 +473,12 @@ def stop(self) -> None: # ---- Unity process management ----------------------------------------- def _resolve_binary(self) -> Path | None: - """Find the Unity binary, downloading if needed. Returns None to skip launch.""" + """Find the Unity binary from config or cache. Does NOT download. + + Downloads happen on the HOST via resolve_unity_binary() (called from + the blueprint requirement hook) before the module is deployed to a + worker subprocess. + """ cfg = self.config # Explicit path provided @@ -421,28 +490,20 @@ def _resolve_binary(self) -> Path | None: p = (Path(__file__).resolve().parent / cfg.unity_binary).resolve() if p.exists(): return p - if not cfg.auto_download: - logger.error( - f"Unity binary not found at {p} and auto_download is disabled. " - f"Set unity_binary to a valid path or enable auto_download." - ) - return None - - # Auto-download - if cfg.auto_download: - _validate_platform() - cache = Path(cfg.unity_cache_dir).expanduser() - candidate = cache / cfg.unity_scene / "environment" / "Model.x86_64" - if candidate.exists(): - return candidate - logger.info(f"Unity binary not found, downloading scene '{cfg.unity_scene}'...") - return _download_unity_scene(cfg.unity_scene, cache) + logger.warning(f"Unity binary not found at {p}") + return None + + # Check cache (download already happened on host) + cache = Path(cfg.unity_cache_dir).expanduser() + candidate = cache / cfg.unity_scene / "environment" / "Model.x86_64" + if candidate.exists(): + return candidate return None def _launch_unity(self) -> None: """Launch the Unity simulator binary as a subprocess.""" - binary_path = self._resolve_binary() + binary_path = self._binary_path if binary_path is None: logger.info("No Unity binary — TCP server will wait for external connection") return diff --git a/dimos/simulation/unity/test_unity_sim.py b/dimos/simulation/unity/test_unity_sim.py index fb8bf0b7a5..4cb889b9e6 100644 --- a/dimos/simulation/unity/test_unity_sim.py +++ b/dimos/simulation/unity/test_unity_sim.py @@ -172,7 +172,7 @@ def test_rejects_unsupported_platform(self): class TestPickle: def test_module_survives_pickle(self): - m = UnityBridgeModule(unity_binary="") + m = UnityBridgeModule(unity_binary="", auto_download=False) m2 = pickle.loads(pickle.dumps(m)) assert hasattr(m2, "_cmd_lock") assert m2._running is False @@ -205,7 +205,7 @@ class TestTCPBridge: def test_handshake_and_data_flow(self): """Mock Unity connects, sends a PointCloud2, verifies bridge publishes it.""" port = _find_free_port() - m = UnityBridgeModule(unity_binary="", unity_port=port) + m = UnityBridgeModule(unity_binary="", auto_download=False, unity_port=port) ts = _wire(m) m._running = True @@ -240,7 +240,7 @@ def test_handshake_and_data_flow(self): class TestKinematicSim: def test_odometry_published(self): - m = UnityBridgeModule(unity_binary="", sim_rate=100.0) + m = UnityBridgeModule(unity_binary="", auto_download=False, sim_rate=100.0) ts = _wire(m) m._running = True @@ -255,7 +255,7 @@ def test_odometry_published(self): assert ts["odometry"]._messages[0].frame_id == "map" def test_cmd_vel_moves_robot(self): - m = UnityBridgeModule(unity_binary="", sim_rate=200.0) + m = UnityBridgeModule(unity_binary="", auto_download=False, sim_rate=200.0) ts = _wire(m) m._on_cmd_vel(Twist(linear=[1.0, 0.0, 0.0], angular=[0.0, 0.0, 0.0])) From d3356856ed8151bd664478537f8b6697b72accb3 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 12 Mar 2026 14:53:52 -0700 Subject: [PATCH 06/42] feat(unity-sim): use LFS for sim binary, remove Google Drive download Replace gdown/Google Drive auto-download with get_data() LFS asset (unity_sim_x86, 128MB compressed). Simplify config by removing unity_scene, unity_cache_dir, auto_download fields. Clean up blueprint (remove __main__.py, rename to unity_sim, remove resolve_unity_binary requirement hook). --- .gitignore | 3 - data/.lfs/unity_sim_x86.tar.gz | 3 + dimos/simulation/unity/blueprint.py | 4 +- dimos/simulation/unity/module.py | 178 +++-------------------- dimos/simulation/unity/test_unity_sim.py | 9 +- 5 files changed, 31 insertions(+), 166 deletions(-) create mode 100644 data/.lfs/unity_sim_x86.tar.gz diff --git a/.gitignore b/.gitignore index 21bac35ead..4045db012e 100644 --- a/.gitignore +++ b/.gitignore @@ -73,9 +73,6 @@ CLAUDE.MD /.mcp.json *.speedscope.json -# Unity simulator cache -.unity_envs/ - # Coverage htmlcov/ .coverage diff --git a/data/.lfs/unity_sim_x86.tar.gz b/data/.lfs/unity_sim_x86.tar.gz new file mode 100644 index 0000000000..00212578a9 --- /dev/null +++ b/data/.lfs/unity_sim_x86.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b02bb692abceedb05e5d85efc0f9c1b1f0d605b4ae011c1a98d35c64036abc11 +size 133299059 diff --git a/dimos/simulation/unity/blueprint.py b/dimos/simulation/unity/blueprint.py index 36d2005fcb..cceb3e697e 100644 --- a/dimos/simulation/unity/blueprint.py +++ b/dimos/simulation/unity/blueprint.py @@ -27,7 +27,7 @@ from dimos.core.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from dimos.simulation.unity.module import UnityBridgeModule, resolve_unity_binary +from dimos.simulation.unity.module import UnityBridgeModule from dimos.visualization.rerun.bridge import _resolve_viewer_mode, rerun_bridge @@ -58,4 +58,4 @@ def _rerun_blueprint() -> Any: unity_sim = autoconnect( UnityBridgeModule.blueprint(), rerun_bridge(viewer_mode=_resolve_viewer_mode(), **rerun_config), -).requirements(resolve_unity_binary()) +) diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index 3fec1ad7b1..d16420495a 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -41,8 +41,7 @@ import subprocess import threading import time -from typing import TYPE_CHECKING, Any -import zipfile +from typing import Any import numpy as np from reactivex.disposable import Disposable @@ -59,6 +58,7 @@ from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger from dimos.utils.ros1 import ( deserialize_compressed_image, @@ -66,15 +66,11 @@ serialize_pose_stamped, ) -if TYPE_CHECKING: - from collections.abc import Callable - logger = setup_logger() PI = math.pi -# Google Drive folder containing environment zips -_GDRIVE_FOLDER_ID = "1UD5v6cSfcwIMWmsq9WSk7blJut4kgb-1" -_DEFAULT_SCENE = "japanese_room_1" +# LFS data asset name for the Unity sim binary +_LFS_ASSET = "unity_sim_x86" _SUPPORTED_SYSTEMS = {"Linux"} _SUPPORTED_ARCHS = {"x86_64", "AMD64"} @@ -122,89 +118,6 @@ def _write_tcp_command(sock: socket.socket, command: str, params: dict[str, Any] ) -# --------------------------------------------------------------------------- -# Auto-download -# --------------------------------------------------------------------------- - - -def _download_unity_scene(scene: str, dest_dir: Path) -> Path: - """Download a Unity environment zip from Google Drive and extract it. - - Returns the path to the Model.x86_64 binary. - """ - try: - import gdown # type: ignore[import-untyped] - except ImportError: - raise RuntimeError( - "Unity sim binary not found and 'gdown' is not installed for auto-download. " - "Install it with: pip install gdown\n" - "Or manually download from: " - f"https://drive.google.com/drive/folders/{_GDRIVE_FOLDER_ID}" - ) - - dest_dir.mkdir(parents=True, exist_ok=True) - zip_path = dest_dir / f"{scene}.zip" - - if not zip_path.exists(): - print(flush=True) - print("=" * 70, flush=True) - print("", flush=True) - print(" UNITY SIMULATOR DOWNLOAD", flush=True) - print("", flush=True) - print(f" The Unity simulator scene '{scene}' was not found locally.", flush=True) - print(" Downloading it now from Google Drive. This is a ONE-TIME", flush=True) - print(" download — future runs will use the cached binary.", flush=True) - print("", flush=True) - print(" Source: VLA Challenge (Google Drive)", flush=True) - print(f" Scene: {scene}", flush=True) - print(" Size: ~130-580 MB (depends on scene)", flush=True) - print(f" Cache: {dest_dir}", flush=True) - print("", flush=True) - print(" gdown will print progress below. This may take a few", flush=True) - print(" minutes depending on your connection speed.", flush=True) - print("", flush=True) - print("=" * 70, flush=True) - print(flush=True) - gdown.download_folder( - id=_GDRIVE_FOLDER_ID, - output=str(dest_dir), - quiet=False, - ) - print(flush=True) - print("=" * 70, flush=True) - print(" Download complete. Locating scene zip...", flush=True) - print("=" * 70, flush=True) - print(flush=True) - # gdown downloads all scenes into a subfolder; find our zip - for candidate in dest_dir.rglob(f"{scene}.zip"): - zip_path = candidate - break - - if not zip_path.exists(): - raise FileNotFoundError( - f"Failed to download scene '{scene}'. " - f"Check https://drive.google.com/drive/folders/{_GDRIVE_FOLDER_ID}" - ) - - # Extract - extract_dir = dest_dir / scene - if not extract_dir.exists(): - print(f" Extracting {zip_path.name}...", flush=True) - with zipfile.ZipFile(zip_path, "r") as zf: - zf.extractall(dest_dir) - print(" Extraction complete.", flush=True) - - binary = extract_dir / "environment" / "Model.x86_64" - if not binary.exists(): - raise FileNotFoundError( - f"Extracted scene but Model.x86_64 not found at {binary}. " - f"Expected structure: {scene}/environment/Model.x86_64" - ) - - binary.chmod(binary.stat().st_mode | 0o111) - return binary - - # --------------------------------------------------------------------------- # Platform validation # --------------------------------------------------------------------------- @@ -229,47 +142,6 @@ def _validate_platform() -> None: ) -# --------------------------------------------------------------------------- -# Host-side binary resolution (runs BEFORE worker deploy) -# --------------------------------------------------------------------------- - - -def resolve_unity_binary( - scene: str = _DEFAULT_SCENE, - cache_dir: str = ".unity_envs", - auto_download: bool = True, -) -> Callable[[], str | None]: - """Return a blueprint requirement check that resolves the Unity binary. - - This runs on the HOST process during blueprint.build(), before modules - are deployed to worker subprocesses. If the binary is not cached and - auto_download is True, it downloads the scene from Google Drive. - - Usage in a blueprint:: - - unity_sim = autoconnect( - UnityBridgeModule.blueprint(), - ... - ).requirements(resolve_unity_binary()) - """ - - def _check() -> str | None: - cache = Path(cache_dir).expanduser() - candidate = cache / scene / "environment" / "Model.x86_64" - if candidate.exists(): - return None # already cached, no error - - if not auto_download: - return f"Unity scene '{scene}' not found at {candidate} and auto_download is disabled." - - _validate_platform() - logger.info(f"Unity binary not found, downloading scene '{scene}'...") - _download_unity_scene(scene, cache) - return None # success - - return _check - - # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- @@ -279,25 +151,15 @@ def _check() -> str | None: class UnityBridgeConfig(ModuleConfig): """Configuration for the Unity bridge / vehicle simulator. - Set ``unity_binary=""`` to skip launching Unity and connect to an - already-running instance. Set ``auto_download=True`` (default) to - automatically download the scene if the binary is missing. + Set ``unity_binary=""`` to auto-resolve from LFS data (default). + Set to an explicit path to use a custom binary. The LFS asset + ``unity_sim_x86`` is pulled automatically via ``get_data()``. """ - # Path to the Unity x86_64 binary. Relative paths resolved from cwd. - # Leave empty to auto-detect from cache or auto-download. + # Path to the Unity x86_64 binary. Leave empty to auto-resolve + # from LFS data (unity_sim_x86/environment/Model.x86_64). unity_binary: str = "" - # Scene name for auto-download (e.g. "office_1", "hotel_room_1"). - # Only used when unity_binary is not found and auto_download is True. - unity_scene: str = _DEFAULT_SCENE - - # Directory to download/cache Unity scenes (relative to cwd). - unity_cache_dir: str = ".unity_envs" - - # Auto-download the scene from Google Drive if binary is missing. - auto_download: bool = True - # Max seconds to wait for Unity to connect after launch. unity_connect_timeout: float = 30.0 @@ -473,11 +335,11 @@ def stop(self) -> None: # ---- Unity process management ----------------------------------------- def _resolve_binary(self) -> Path | None: - """Find the Unity binary from config or cache. Does NOT download. + """Find the Unity binary from config or LFS data. - Downloads happen on the HOST via resolve_unity_binary() (called from - the blueprint requirement hook) before the module is deployed to a - worker subprocess. + When ``unity_binary`` is empty (default), pulls the LFS asset + ``unity_sim_x86`` via ``get_data()`` and returns the path to + ``environment/Model.x86_64``. """ cfg = self.config @@ -493,11 +355,15 @@ def _resolve_binary(self) -> Path | None: logger.warning(f"Unity binary not found at {p}") return None - # Check cache (download already happened on host) - cache = Path(cfg.unity_cache_dir).expanduser() - candidate = cache / cfg.unity_scene / "environment" / "Model.x86_64" - if candidate.exists(): - return candidate + # Pull from LFS (auto-downloads + extracts on first use) + try: + data_dir = get_data(_LFS_ASSET) + candidate = data_dir / "environment" / "Model.x86_64" + if candidate.exists(): + return candidate + logger.warning(f"LFS asset '{_LFS_ASSET}' extracted but Model.x86_64 not found") + except Exception as e: + logger.warning(f"Failed to resolve Unity binary from LFS: {e}") return None diff --git a/dimos/simulation/unity/test_unity_sim.py b/dimos/simulation/unity/test_unity_sim.py index 4cb889b9e6..31f1237f51 100644 --- a/dimos/simulation/unity/test_unity_sim.py +++ b/dimos/simulation/unity/test_unity_sim.py @@ -141,7 +141,6 @@ def test_default_config(self): cfg = UnityBridgeConfig() assert cfg.unity_port == 10000 assert cfg.sim_rate == 200.0 - assert cfg.auto_download is True def test_custom_binary_path(self): cfg = UnityBridgeConfig(unity_binary="/custom/path/Model.x86_64") @@ -172,7 +171,7 @@ def test_rejects_unsupported_platform(self): class TestPickle: def test_module_survives_pickle(self): - m = UnityBridgeModule(unity_binary="", auto_download=False) + m = UnityBridgeModule(unity_binary="") m2 = pickle.loads(pickle.dumps(m)) assert hasattr(m2, "_cmd_lock") assert m2._running is False @@ -205,7 +204,7 @@ class TestTCPBridge: def test_handshake_and_data_flow(self): """Mock Unity connects, sends a PointCloud2, verifies bridge publishes it.""" port = _find_free_port() - m = UnityBridgeModule(unity_binary="", auto_download=False, unity_port=port) + m = UnityBridgeModule(unity_binary="", unity_port=port) ts = _wire(m) m._running = True @@ -240,7 +239,7 @@ def test_handshake_and_data_flow(self): class TestKinematicSim: def test_odometry_published(self): - m = UnityBridgeModule(unity_binary="", auto_download=False, sim_rate=100.0) + m = UnityBridgeModule(unity_binary="", sim_rate=100.0) ts = _wire(m) m._running = True @@ -255,7 +254,7 @@ def test_odometry_published(self): assert ts["odometry"]._messages[0].frame_id == "map" def test_cmd_vel_moves_robot(self): - m = UnityBridgeModule(unity_binary="", auto_download=False, sim_rate=200.0) + m = UnityBridgeModule(unity_binary="", sim_rate=200.0) ts = _wire(m) m._on_cmd_vel(Twist(linear=[1.0, 0.0, 0.0], angular=[0.0, 0.0, 0.0])) From 23e81f5a1551037ae81b4b971bf53339698bb752 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Thu, 12 Mar 2026 23:02:02 +0000 Subject: [PATCH 07/42] Module config tweaks (#1510) * Reapply "Module config adjustments (#1413)" (#1417) This reverts commit 3df8857130deeb4c04aa71200db7603f3d2790d2. * Fixes * Move global_config * Why did this not commit? * Default * Fix * Fix * Fix * Docs * tentative fix for timeout error * Fix * Use create_autospec * TemporalMemory fix * Forbid extra * Fix --------- Co-authored-by: Sam Bull Co-authored-by: Paul Nechifor --- dimos/agents/agent.py | 6 +- dimos/agents/agent_test_runner.py | 19 ++-- dimos/agents/mcp/mcp_client.py | 6 +- dimos/agents/mcp/mcp_server.py | 16 ++- dimos/agents/mcp/test_mcp_client.py | 11 +- .../skills/google_maps_skill_container.py | 4 +- dimos/agents/skills/gps_nav_skill.py | 3 - dimos/agents/skills/navigation.py | 4 +- dimos/agents/skills/person_follow.py | 43 ++++---- .../test_google_maps_skill_container.py | 9 +- dimos/agents/skills/test_gps_nav_skills.py | 8 +- dimos/agents/skills/test_navigation.py | 15 +-- .../skills/test_unitree_skill_container.py | 5 +- dimos/agents/test_agent.py | 11 +- dimos/agents/vlm_agent.py | 11 +- dimos/agents_deprecated/modules/base_agent.py | 58 +++++----- dimos/control/coordinator.py | 1 - dimos/core/blueprints.py | 57 +++++----- dimos/core/docker_runner.py | 3 +- dimos/core/introspection/blueprint/dot.py | 10 +- dimos/core/module.py | 54 ++++++---- dimos/core/module_coordinator.py | 23 ++-- dimos/core/native_module.py | 41 ++++--- dimos/core/test_blueprints.py | 9 +- dimos/core/test_core.py | 3 - dimos/core/test_native_module.py | 2 - dimos/core/test_stream.py | 9 +- dimos/core/test_worker.py | 21 ++-- dimos/core/testing.py | 5 +- dimos/core/worker.py | 28 ++--- dimos/core/worker_manager.py | 30 +++--- .../camera/gstreamer/gstreamer_camera.py | 31 +++--- dimos/hardware/sensors/camera/module.py | 17 +-- .../sensors/camera/realsense/camera.py | 8 +- dimos/hardware/sensors/camera/spec.py | 10 +- dimos/hardware/sensors/camera/zed/__init__.py | 8 +- dimos/hardware/sensors/camera/zed/camera.py | 8 +- dimos/hardware/sensors/camera/zed/test_zed.py | 7 +- dimos/hardware/sensors/fake_zed_module.py | 9 +- .../hardware/sensors/lidar/fastlio2/module.py | 29 +++-- dimos/hardware/sensors/lidar/livox/module.py | 10 +- .../cartesian_motion_controller.py | 13 ++- .../joint_trajectory_controller.py | 9 +- .../manipulation/grasping/graspgen_module.py | 12 +-- dimos/manipulation/manipulation_module.py | 14 +-- dimos/manipulation/pick_and_place_module.py | 14 ++- dimos/manipulation/planning/spec/config.py | 25 +++-- dimos/mapping/costmapper.py | 14 +-- dimos/mapping/osm/current_location_map.py | 6 +- dimos/mapping/osm/query.py | 7 +- dimos/mapping/voxels.py | 10 +- dimos/memory/embedding.py | 7 +- dimos/models/base.py | 7 +- dimos/models/embedding/base.py | 3 - dimos/models/embedding/clip.py | 2 - dimos/models/embedding/mobileclip.py | 2 - dimos/models/embedding/treid.py | 2 - dimos/models/vl/base.py | 24 +++-- dimos/models/vl/create.py | 4 +- dimos/models/vl/moondream.py | 9 +- dimos/models/vl/moondream_hosted.py | 13 +-- dimos/models/vl/openai.py | 5 +- dimos/models/vl/qwen.py | 5 +- dimos/navigation/bbox_navigation.py | 22 ++-- .../test_wavefront_frontier_goal_selector.py | 2 +- .../wavefront_frontier_goal_selector.py | 70 ++++++------ dimos/navigation/replanning_a_star/module.py | 10 +- dimos/navigation/rosnav.py | 18 ++-- dimos/navigation/visual/query.py | 3 +- dimos/perception/detection/conftest.py | 7 +- dimos/perception/detection/module2D.py | 25 ++--- .../temporal_memory/entity_graph_db.py | 2 +- .../temporal_memory/temporal_memory.py | 83 +++++++------- .../test_temporal_memory_module.py | 30 +++--- .../temporal_memory/window_analyzer.py | 6 +- dimos/perception/object_tracker.py | 20 ++-- dimos/perception/object_tracker_2d.py | 6 +- dimos/perception/perceive_loop_skill.py | 14 +-- dimos/perception/spatial_perception.py | 102 +++++++++--------- .../perception/test_spatial_memory_module.py | 17 +-- dimos/protocol/pubsub/bridge.py | 6 +- dimos/protocol/pubsub/impl/lcmpubsub.py | 14 +-- dimos/protocol/pubsub/impl/redispubsub.py | 9 +- dimos/protocol/service/__init__.py | 7 +- dimos/protocol/service/ddsservice.py | 11 +- dimos/protocol/service/lcmservice.py | 37 ++++--- dimos/protocol/service/spec.py | 15 ++- dimos/protocol/service/test_lcmservice.py | 50 +++++---- dimos/protocol/tf/tf.py | 28 ++--- dimos/protocol/tf/tflcmcpp.py | 9 +- dimos/robot/drone/connection_module.py | 39 ++++--- dimos/robot/foxglove_bridge.py | 36 +++---- dimos/robot/unitree/b1/connection.py | 30 ++++-- dimos/robot/unitree/b1/joystick_module.py | 9 +- dimos/robot/unitree/b1/unitree_b1.py | 4 +- dimos/robot/unitree/g1/connection.py | 47 ++++---- dimos/robot/unitree/g1/sim.py | 35 +++--- dimos/robot/unitree/go2/connection.py | 43 ++++---- dimos/robot/unitree/go2/fleet_connection.py | 50 +++++---- dimos/robot/unitree/keyboard_teleop.py | 5 +- dimos/robot/unitree/mujoco_connection.py | 1 + dimos/robot/unitree/rosnav.py | 6 +- dimos/robot/unitree/type/map.py | 53 +++++---- dimos/simulation/manipulators/sim_module.py | 7 +- .../manipulators/test_sim_module.py | 3 +- .../teleop/keyboard/keyboard_teleop_module.py | 6 +- dimos/teleop/phone/phone_teleop_module.py | 6 +- dimos/teleop/quest/quest_extensions.py | 17 ++- dimos/teleop/quest/quest_teleop_module.py | 14 +-- dimos/utils/cli/lcmspy/lcmspy.py | 14 ++- dimos/visualization/rerun/bridge.py | 25 ++--- .../web/websocket_vis/websocket_vis_module.py | 31 +++--- docs/usage/blueprints.md | 28 ++--- docs/usage/configuration.md | 17 +-- docs/usage/native_modules.md | 20 +--- docs/usage/transforms.md | 4 - examples/simplerobot/simplerobot.py | 12 +-- pyproject.toml | 8 +- 118 files changed, 959 insertions(+), 1103 deletions(-) diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index 37e1a4757c..ab576fb109 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import json from queue import Empty, Queue from threading import Event, RLock, Thread @@ -38,7 +37,6 @@ from langchain_core.language_models import BaseChatModel -@dataclass class AgentConfig(ModuleConfig): system_prompt: str | None = SYSTEM_PROMPT model: str = "gpt-4o" @@ -58,8 +56,8 @@ class Agent(Module[AgentConfig]): _thread: Thread _stop_event: Event - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._lock = RLock() self._state_graph = None self._message_queue = Queue() diff --git a/dimos/agents/agent_test_runner.py b/dimos/agents/agent_test_runner.py index 7d7fbab03d..7a4ba2a94e 100644 --- a/dimos/agents/agent_test_runner.py +++ b/dimos/agents/agent_test_runner.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterable from threading import Event, Thread +from typing import Any from langchain_core.messages import AIMessage from langchain_core.messages.base import BaseMessage @@ -20,21 +22,26 @@ from dimos.agents.agent import AgentSpec from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.rpc_client import RPCClient from dimos.core.stream import In, Out -class AgentTestRunner(Module): +class Config(ModuleConfig): + messages: Iterable[BaseMessage] + + +class AgentTestRunner(Module[Config]): + default_config = Config + agent_spec: AgentSpec agent: In[BaseMessage] agent_idle: In[bool] finished: Out[bool] added: Out[bool] - def __init__(self, messages: list[BaseMessage]) -> None: - super().__init__() - self._messages = messages + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._idle_event = Event() self._subscription_ready = Event() self._thread = Thread(target=self._thread_loop, daemon=True) @@ -71,7 +78,7 @@ def _thread_loop(self) -> None: if not self._subscription_ready.wait(5): raise TimeoutError("Timed out waiting for subscription to be ready.") - for message in self._messages: + for message in self.config.messages: self._idle_event.clear() self.agent_spec.add_message(message) if not self._idle_event.wait(60): diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index 7c5eda5302..a2ee872e16 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from queue import Empty, Queue from threading import Event, RLock, Thread import time @@ -39,7 +38,6 @@ logger = setup_logger() -@dataclass class McpClientConfig(ModuleConfig): system_prompt: str | None = SYSTEM_PROMPT model: str = "gpt-4o" @@ -62,8 +60,8 @@ class McpClient(Module[McpClientConfig]): _http_client: httpx.Client _seq_ids: SequentialIds - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._lock = RLock() self._state_graph = None self._message_queue = Queue() diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index bfd45bc58a..e5697542fb 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -14,6 +14,7 @@ from __future__ import annotations import asyncio +import concurrent.futures import json import os import time @@ -22,7 +23,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse -from starlette.requests import Request # noqa: TC002 +from starlette.requests import Request from starlette.responses import Response import uvicorn @@ -32,14 +33,11 @@ from dimos.core.rpc_client import RpcCall, RPCClient from dimos.utils.logging_config import setup_logger -logger = setup_logger() - - if TYPE_CHECKING: - import concurrent.futures - from dimos.core.module import SkillInfo +logger = setup_logger() + app = FastAPI() app.add_middleware( @@ -185,10 +183,8 @@ async def mcp_endpoint(request: Request) -> Response: class McpServer(Module): - def __init__(self) -> None: - super().__init__() - self._uvicorn_server: uvicorn.Server | None = None - self._serve_future: concurrent.futures.Future[None] | None = None + _uvicorn_server: uvicorn.Server | None = None + _serve_future: concurrent.futures.Future[None] | None = None @rpc def start(self) -> None: diff --git a/dimos/agents/mcp/test_mcp_client.py b/dimos/agents/mcp/test_mcp_client.py index 16427103e4..56b98c3cd2 100644 --- a/dimos/agents/mcp/test_mcp_client.py +++ b/dimos/agents/mcp/test_mcp_client.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from langchain_core.messages import HumanMessage import pytest @@ -40,10 +41,8 @@ def test_can_call_tool(agent_setup): class UserRegistration(Module): - def __init__(self): - super().__init__() - self._first_call = True - self._use_upper = False + _first_call = True + _use_upper = False @skill def register_user(self, name: str) -> str: @@ -79,8 +78,8 @@ def test_can_call_again_on_error(agent_setup): class MultipleTools(Module): - def __init__(self): - super().__init__() + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) self._people = {"Ben": "office", "Bob": "garage"} @skill diff --git a/dimos/agents/skills/google_maps_skill_container.py b/dimos/agents/skills/google_maps_skill_container.py index 7e402e32d7..c03932924f 100644 --- a/dimos/agents/skills/google_maps_skill_container.py +++ b/dimos/agents/skills/google_maps_skill_container.py @@ -32,8 +32,8 @@ class GoogleMapsSkillContainer(Module): gps_location: In[LatLon] - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) try: self._client = GoogleMaps() except ValueError: diff --git a/dimos/agents/skills/gps_nav_skill.py b/dimos/agents/skills/gps_nav_skill.py index 721119f6e6..63cf4a3dd3 100644 --- a/dimos/agents/skills/gps_nav_skill.py +++ b/dimos/agents/skills/gps_nav_skill.py @@ -34,9 +34,6 @@ class GpsNavSkillContainer(Module): gps_location: In[LatLon] gps_goal: Out[LatLon] - def __init__(self) -> None: - super().__init__() - @rpc def start(self) -> None: super().start() diff --git a/dimos/agents/skills/navigation.py b/dimos/agents/skills/navigation.py index b02ff3a446..8442846f32 100644 --- a/dimos/agents/skills/navigation.py +++ b/dimos/agents/skills/navigation.py @@ -55,8 +55,8 @@ class NavigationSkillContainer(Module): color_image: In[Image] odom: In[PoseStamped] - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._skill_started = False # Here to prevent unwanted imports in the file. diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index e59ddb3b2a..7a6c6ecfe9 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -14,7 +14,7 @@ from threading import Event, RLock, Thread import time -from typing import TYPE_CHECKING +from typing import Any from langchain_core.messages import HumanMessage import numpy as np @@ -23,10 +23,11 @@ from dimos.agents.agent import AgentSpec from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.models.qwen.bbox import BBox +from dimos.models.segmentation.edge_tam import EdgeTAMProcessor +from dimos.models.vl.base import VlModel from dimos.models.vl.create import create from dimos.msgs.geometry_msgs import Twist from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 @@ -35,14 +36,15 @@ from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from dimos.models.segmentation.edge_tam import EdgeTAMProcessor - from dimos.models.vl.base import VlModel - logger = setup_logger() -class PersonFollowSkillContainer(Module): +class Config(ModuleConfig): + camera_info: CameraInfo + use_3d_navigation: bool = False + + +class PersonFollowSkillContainer(Module[Config]): """Skill container for following a person. This skill uses: @@ -52,6 +54,8 @@ class PersonFollowSkillContainer(Module): - Does not do obstacle avoidance; assumes a clear path. """ + default_config = Config + color_image: In[Image] global_map: In[PointCloud2] cmd_vel: Out[Twist] @@ -60,38 +64,31 @@ class PersonFollowSkillContainer(Module): _frequency: float = 20.0 # Hz - control loop frequency _max_lost_frames: int = 15 # number of frames to wait before declaring person lost - def __init__( - self, - camera_info: CameraInfo, - cfg: GlobalConfig, - use_3d_navigation: bool = False, - ) -> None: - super().__init__() - self._global_config: GlobalConfig = cfg - self._use_3d_navigation: bool = use_3d_navigation + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._latest_image: Image | None = None self._latest_pointcloud: PointCloud2 | None = None - self._vl_model: VlModel = create("qwen") + self._vl_model: VlModel[Any] = create("qwen") self._tracker: EdgeTAMProcessor | None = None self._thread: Thread | None = None self._should_stop: Event = Event() self._lock = RLock() # Use MuJoCo camera intrinsics in simulation mode - if self._global_config.simulation: + camera_info = self.config.camera_info + if self.config.g.simulation: from dimos.robot.unitree.mujoco_connection import MujocoConnection camera_info = MujocoConnection.camera_info_static - self._camera_info = camera_info - self._visual_servo = VisualServoing2D(camera_info, self._global_config.simulation) + self._visual_servo = VisualServoing2D(camera_info, self.config.g.simulation) self._detection_navigation = DetectionNavigation(self.tf, camera_info) @rpc def start(self) -> None: super().start() self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) - if self._use_3d_navigation: + if self.config.use_3d_navigation: self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud))) @rpc @@ -230,7 +227,7 @@ def _follow_loop(self, tracker: "EdgeTAMProcessor", query: str) -> None: lost_count = 0 best_detection = max(detections.detections, key=lambda d: d.bbox_2d_volume()) - if self._use_3d_navigation: + if self.config.use_3d_navigation: with self._lock: pointcloud = self._latest_pointcloud if pointcloud is None: diff --git a/dimos/agents/skills/test_google_maps_skill_container.py b/dimos/agents/skills/test_google_maps_skill_container.py index 1d8e4549b0..1519f9d1df 100644 --- a/dimos/agents/skills/test_google_maps_skill_container.py +++ b/dimos/agents/skills/test_google_maps_skill_container.py @@ -13,6 +13,7 @@ # limitations under the License. import re +from typing import Any from langchain_core.messages import HumanMessage import pytest @@ -39,8 +40,8 @@ def get_location_context(self, location, radius=200): class MockedWhereAmISkill(GoogleMapsSkillContainer): - def __init__(self): - Module.__init__(self) # Skip GoogleMapsSkillContainer's __init__. + def __init__(self, **kwargs: Any): + Module.__init__(self, **kwargs) # Skip GoogleMapsSkillContainer's __init__. self._client = FakeLocationClient() self._latest_location = LatLon(lat=37.782654, lon=-122.413273) self._started = True @@ -62,8 +63,8 @@ def get_position(self, query, location): class MockedPositionSkill(GoogleMapsSkillContainer): - def __init__(self): - Module.__init__(self) + def __init__(self, **kwargs: Any): + Module.__init__(self, **kwargs) self._client = FakePositionClient() self._latest_location = LatLon(lat=37.782654, lon=-122.413273) self._started = True diff --git a/dimos/agents/skills/test_gps_nav_skills.py b/dimos/agents/skills/test_gps_nav_skills.py index d701d469ca..4060b1814e 100644 --- a/dimos/agents/skills/test_gps_nav_skills.py +++ b/dimos/agents/skills/test_gps_nav_skills.py @@ -28,11 +28,9 @@ class FakeGPS(Module): class MockedGpsNavSkill(GpsNavSkillContainer): - def __init__(self): - Module.__init__(self) - self._latest_location = LatLon(lat=37.782654, lon=-122.413273) - self._started = True - self._max_valid_distance = 50000 + _latest_location = LatLon(lat=37.782654, lon=-122.413273) + _started = True + _max_valid_distance = 50000 @pytest.mark.slow diff --git a/dimos/agents/skills/test_navigation.py b/dimos/agents/skills/test_navigation.py index a7505b23c7..e31fae93b5 100644 --- a/dimos/agents/skills/test_navigation.py +++ b/dimos/agents/skills/test_navigation.py @@ -31,23 +31,17 @@ class FakeOdom(Module): class MockedStopNavSkill(NavigationSkillContainer): + _skill_started = True rpc_calls: list[str] = [] - def __init__(self): - Module.__init__(self) - self._skill_started = True - def _cancel_goal_and_stop(self): pass class MockedExploreNavSkill(NavigationSkillContainer): + _skill_started = True rpc_calls: list[str] = [] - def __init__(self): - Module.__init__(self) - self._skill_started = True - def _start_exploration(self, timeout): return "Exploration completed successfuly" @@ -56,12 +50,9 @@ def _cancel_goal_and_stop(self): class MockedSemanticNavSkill(NavigationSkillContainer): + _skill_started = True rpc_calls: list[str] = [] - def __init__(self): - Module.__init__(self) - self._skill_started = True - def _navigate_by_tagged_location(self, query): return None diff --git a/dimos/agents/skills/test_unitree_skill_container.py b/dimos/agents/skills/test_unitree_skill_container.py index dde7239bbd..92b006dce5 100644 --- a/dimos/agents/skills/test_unitree_skill_container.py +++ b/dimos/agents/skills/test_unitree_skill_container.py @@ -13,6 +13,7 @@ # limitations under the License. import difflib +from typing import Any from langchain_core.messages import HumanMessage import pytest @@ -23,8 +24,8 @@ class MockedUnitreeSkill(UnitreeSkillContainer): rpc_calls: list[str] = [] - def __init__(self): - super().__init__() + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) # Provide a fake RPC so the real execute_sport_command runs end-to-end. self._bound_rpc_calls["GO2Connection.publish_request"] = lambda *args, **kwargs: None diff --git a/dimos/agents/test_agent.py b/dimos/agents/test_agent.py index 2464e622ca..bb6caa6337 100644 --- a/dimos/agents/test_agent.py +++ b/dimos/agents/test_agent.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from langchain_core.messages import HumanMessage import pytest @@ -40,10 +41,8 @@ def test_can_call_tool(agent_setup): class UserRegistration(Module): - def __init__(self): - super().__init__() - self._first_call = True - self._use_upper = False + _first_call = True + _use_upper = False @skill def register_user(self, name: str) -> str: @@ -81,8 +80,8 @@ def test_can_call_again_on_error(agent_setup): class MultipleTools(Module): - def __init__(self): - super().__init__() + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) self._people = {"Ben": "office", "Bob": "garage"} @skill diff --git a/dimos/agents/vlm_agent.py b/dimos/agents/vlm_agent.py index ec0aec1442..c39f79830a 100644 --- a/dimos/agents/vlm_agent.py +++ b/dimos/agents/vlm_agent.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import TYPE_CHECKING, Any from langchain.chat_models import init_chat_model @@ -31,24 +30,22 @@ logger = setup_logger() -@dataclass class VLMAgentConfig(ModuleConfig): model: str = "gpt-4o" system_prompt: str | None = SYSTEM_PROMPT -class VLMAgent(Module): +class VLMAgent(Module[VLMAgentConfig]): """Stream-first agent for vision queries with optional RPC access.""" - default_config: type[VLMAgentConfig] = VLMAgentConfig - config: VLMAgentConfig + default_config = VLMAgentConfig color_image: In[Image] query_stream: In[HumanMessage] answer_stream: Out[AIMessage] - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) if self.config.model.startswith("ollama:"): from dimos.agents.ollama_agent import ensure_ollama_model diff --git a/dimos/agents_deprecated/modules/base_agent.py b/dimos/agents_deprecated/modules/base_agent.py index 18ac15b317..d524861f77 100644 --- a/dimos/agents_deprecated/modules/base_agent.py +++ b/dimos/agents_deprecated/modules/base_agent.py @@ -21,7 +21,7 @@ from dimos.agents_deprecated.agent_types import AgentResponse from dimos.agents_deprecated.memory.base import AbstractAgentSemanticMemory from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.skills.skills import AbstractSkill, SkillLibrary from dimos.utils.logging_config import setup_logger @@ -34,32 +34,34 @@ logger = setup_logger() -class BaseAgentModule(BaseAgent, Module): # type: ignore[misc] +class BaseAgentConfig(ModuleConfig): + model: str = "openai::gpt-4o-mini" + system_prompt: str | None = None + skills: SkillLibrary | list[AbstractSkill] | AbstractSkill | None = None + memory: AbstractAgentSemanticMemory | None = None + temperature: float = 0.0 + max_tokens: int = 4096 + max_input_tokens: int = 128000 + max_history: int = 20 + rag_n: int = 4 + rag_threshold: float = 0.45 + process_all_inputs: bool = False + + +class BaseAgentModule(BaseAgent, Module[BaseAgentConfig]): # type: ignore[misc] """Agent module that inherits from BaseAgent and adds DimOS module interface. This provides a thin wrapper around BaseAgent functionality, exposing it through the DimOS module system with RPC methods and stream I/O. """ + default_config = BaseAgentConfig + # Module I/O - AgentMessage based communication message_in: In[AgentMessage] # Primary input for AgentMessage response_out: Out[AgentResponse] # Output AgentResponse objects - def __init__( # type: ignore[no-untyped-def] - self, - model: str = "openai::gpt-4o-mini", - system_prompt: str | None = None, - skills: SkillLibrary | list[AbstractSkill] | AbstractSkill | None = None, - memory: AbstractAgentSemanticMemory | None = None, - temperature: float = 0.0, - max_tokens: int = 4096, - max_input_tokens: int = 128000, - max_history: int = 20, - rag_n: int = 4, - rag_threshold: float = 0.45, - process_all_inputs: bool = False, - **kwargs, - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize the agent module. Args: @@ -82,17 +84,17 @@ def __init__( # type: ignore[no-untyped-def] # Initialize BaseAgent with all functionality BaseAgent.__init__( self, - model=model, - system_prompt=system_prompt, - skills=skills, - memory=memory, - temperature=temperature, - max_tokens=max_tokens, - max_input_tokens=max_input_tokens, - max_history=max_history, - rag_n=rag_n, - rag_threshold=rag_threshold, - process_all_inputs=process_all_inputs, + model=self.config.model, + system_prompt=self.config.system_prompt, + skills=self.config.skills, + memory=self.config.memory, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + max_input_tokens=self.config.max_input_tokens, + max_history=self.config.max_history, + rag_n=self.config.rag_n, + rag_threshold=self.config.rag_threshold, + process_all_inputs=self.config.process_all_inputs, # Don't pass streams - we'll connect them in start() input_query_stream=None, input_data_stream=None, diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index 21d4c9d06c..73e036e873 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -104,7 +104,6 @@ class TaskConfig: gripper_closed_pos: float = 0.0 -@dataclass class ControlCoordinatorConfig(ModuleConfig): """Configuration for the ControlCoordinator. diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index 287697f6c0..abfeb29b2f 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -17,7 +17,6 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass, field, replace from functools import cached_property, reduce -import inspect import operator import sys from types import MappingProxyType @@ -27,7 +26,7 @@ from dimos.protocol.service.system_configurator.base import SystemConfigurator from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module, is_module_type +from dimos.core.module import Module, ModuleBase, ModuleSpec, is_module_type from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport @@ -35,6 +34,11 @@ from dimos.utils.generic import short_id from dimos.utils.logging_config import setup_logger +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing import Any as Self + logger = setup_logger() @@ -48,21 +52,18 @@ class StreamRef: @dataclass(frozen=True) class ModuleRef: name: str - spec: type[Spec] | type[Module] + spec: type[Spec] | type[ModuleBase] @dataclass(frozen=True) class _BlueprintAtom: - module: type[Module] + kwargs: dict[str, Any] + module: type[ModuleBase[Any]] streams: tuple[StreamRef, ...] module_refs: tuple[ModuleRef, ...] - args: tuple[Any, ...] - kwargs: dict[str, Any] @classmethod - def create( - cls, module: type[Module], args: tuple[Any, ...], kwargs: dict[str, Any] - ) -> "_BlueprintAtom": + def create(cls, module: type[ModuleBase[Any]], kwargs: dict[str, Any]) -> Self: streams: list[StreamRef] = [] module_refs: list[ModuleRef] = [] @@ -103,7 +104,6 @@ def create( module=module, streams=tuple(streams), module_refs=tuple(module_refs), - args=args, kwargs=kwargs, ) @@ -111,23 +111,23 @@ def create( @dataclass(frozen=True) class Blueprint: blueprints: tuple[_BlueprintAtom, ...] - disabled_modules_tuple: tuple[type[Module], ...] = field(default_factory=tuple) + disabled_modules_tuple: tuple[type[ModuleBase], ...] = field(default_factory=tuple) transport_map: Mapping[tuple[str, type], PubSubTransport[Any]] = field( default_factory=lambda: MappingProxyType({}) ) global_config_overrides: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({})) - remapping_map: Mapping[tuple[type[Module], str], str | type[Module] | type[Spec]] = field( - default_factory=lambda: MappingProxyType({}) + remapping_map: Mapping[tuple[type[ModuleBase], str], str | type[ModuleBase] | type[Spec]] = ( + field(default_factory=lambda: MappingProxyType({})) ) requirement_checks: tuple[Callable[[], str | None], ...] = field(default_factory=tuple) configurator_checks: "tuple[SystemConfigurator, ...]" = field(default_factory=tuple) @classmethod - def create(cls, module: type[Module], *args: Any, **kwargs: Any) -> "Blueprint": - blueprint = _BlueprintAtom.create(module, args, kwargs) + def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint": + blueprint = _BlueprintAtom.create(module, kwargs) return cls(blueprints=(blueprint,)) - def disabled_modules(self, *modules: type[Module]) -> "Blueprint": + def disabled_modules(self, *modules: type[ModuleBase]) -> "Blueprint": return replace(self, disabled_modules_tuple=self.disabled_modules_tuple + modules) def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint": @@ -140,7 +140,10 @@ def global_config(self, **kwargs: Any) -> "Blueprint": ) def remappings( - self, remappings: list[tuple[type[Module], str, str | type[Module] | type[Spec]]] + self, + remappings: list[ + tuple[type[ModuleBase[Any]], str, str | type[ModuleBase[Any]] | type[Spec]] + ], ) -> "Blueprint": remappings_dict = dict(self.remapping_map) for module, old, new in remappings: @@ -163,8 +166,8 @@ def _active_blueprints(self) -> tuple[_BlueprintAtom, ...]: def _check_ambiguity( self, requested_method_name: str, - interface_methods: Mapping[str, list[tuple[type[Module], Callable[..., Any]]]], - requesting_module: type[Module], + interface_methods: Mapping[str, list[tuple[type[ModuleBase], Callable[..., Any]]]], + requesting_module: type[ModuleBase], ) -> None: if ( requested_method_name in interface_methods @@ -273,13 +276,9 @@ def _verify_no_name_conflicts(self) -> None: def _deploy_all_modules( self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig ) -> None: - module_specs: list[tuple[type[Module], tuple[Any, ...], dict[str, Any]]] = [] + module_specs: list[ModuleSpec] = [] for blueprint in self._active_blueprints: - kwargs = {**blueprint.kwargs} - sig = inspect.signature(blueprint.module.__init__) - if "cfg" in sig.parameters: - kwargs["cfg"] = global_config - module_specs.append((blueprint.module, blueprint.args, kwargs)) + module_specs.append((blueprint.module, global_config, blueprint.kwargs)) module_coordinator.deploy_parallel(module_specs) @@ -399,12 +398,12 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None: rpc_methods_dot = {} # Track interface methods to detect ambiguity. - interface_methods: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = ( + interface_methods: defaultdict[str, list[tuple[type[ModuleBase], Callable[..., Any]]]] = ( defaultdict(list) ) # interface_name_method -> [(module_class, method)] - interface_methods_dot: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = ( - defaultdict(list) - ) # interface_name.method -> [(module_class, method)] + interface_methods_dot: defaultdict[ + str, list[tuple[type[ModuleBase], Callable[..., Any]]] + ] = defaultdict(list) # interface_name.method -> [(module_class, method)] for blueprint in self._active_blueprints: for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined] diff --git a/dimos/core/docker_runner.py b/dimos/core/docker_runner.py index ee56163ca6..99833a9b97 100644 --- a/dimos/core/docker_runner.py +++ b/dimos/core/docker_runner.py @@ -15,7 +15,7 @@ import argparse from contextlib import suppress -from dataclasses import dataclass, field +from dataclasses import field import importlib import json import os @@ -46,7 +46,6 @@ LOG_TAIL_LINES = 200 # Number of log lines to include in error messages -@dataclass(kw_only=True) class DockerModuleConfig(ModuleConfig): """ Configuration for running a DimOS module inside Docker. diff --git a/dimos/core/introspection/blueprint/dot.py b/dimos/core/introspection/blueprint/dot.py index ea66401033..74ee9406a9 100644 --- a/dimos/core/introspection/blueprint/dot.py +++ b/dimos/core/introspection/blueprint/dot.py @@ -31,7 +31,7 @@ color_for_string, sanitize_id, ) -from dimos.core.module import Module +from dimos.core.module import ModuleBase from dimos.utils.cli import theme @@ -82,11 +82,11 @@ def render( ignored_modules = DEFAULT_IGNORED_MODULES # Collect all outputs: (name, type) -> list of producer modules - producers: dict[tuple[str, type], list[type[Module]]] = defaultdict(list) + producers: dict[tuple[str, type], list[type[ModuleBase]]] = defaultdict(list) # Collect all inputs: (name, type) -> list of consumer modules - consumers: dict[tuple[str, type], list[type[Module]]] = defaultdict(list) + consumers: dict[tuple[str, type], list[type[ModuleBase]]] = defaultdict(list) # Module name -> module class (for getting package info) - module_classes: dict[str, type[Module]] = {} + module_classes: dict[str, type[ModuleBase]] = {} for bp in blueprint_set.blueprints: module_classes[bp.module.__name__] = bp.module @@ -117,7 +117,7 @@ def render( active_channels[key] = color_for_string(TYPE_COLORS, label) # Group modules by package - def get_group(mod_class: type[Module]) -> str: + def get_group(mod_class: type[ModuleBase]) -> str: module_path = mod_class.__module__ parts = module_path.split(".") if len(parts) >= 2 and parts[0] == "dimos": diff --git a/dimos/core/module.py b/dimos/core/module.py index 48a99a79a3..ab21ce17a9 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -17,38 +17,43 @@ from functools import partial import inspect import json +import sys import threading from typing import ( TYPE_CHECKING, Any, + Protocol, get_args, get_origin, get_type_hints, overload, ) -from typing_extensions import TypeVar as TypeVarExtension - -if TYPE_CHECKING: - from dimos.core.introspection.module import ModuleInfo - from dimos.core.rpc_client import RPCClient - -from typing import TypeVar - from langchain_core.tools import tool from reactivex.disposable import CompositeDisposable from dimos.core.core import T, rpc +from dimos.core.global_config import GlobalConfig, global_config from dimos.core.introspection.module import extract_module_info, render_module_io from dimos.core.resource import Resource from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out, RemoteOut, Transport from dimos.protocol.rpc import LCMRPC, RPCSpec -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.protocol.service import BaseConfig, Configurable from dimos.protocol.tf import LCMTF, TFSpec from dimos.utils import colors from dimos.utils.generic import classproperty +if TYPE_CHECKING: + from dimos.core.blueprints import Blueprint + from dimos.core.introspection.module import ModuleInfo + from dimos.core.rpc_client import RPCClient + +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + @dataclass(frozen=True) class SkillInfo: @@ -70,20 +75,27 @@ def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: return loop, thr -@dataclass -class ModuleConfig: +class ModuleConfig(BaseConfig): rpc_transport: type[RPCSpec] = LCMRPC - tf_transport: type[TFSpec] = LCMTF + tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] frame_id_prefix: str | None = None frame_id: str | None = None + g: GlobalConfig = global_config + + +ModuleConfigT = TypeVar("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig) -ModuleConfigT = TypeVarExtension("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig) +class _BlueprintPartial(Protocol): + def __call__(self, **kwargs: Any) -> "Blueprint": ... class ModuleBase(Configurable[ModuleConfigT], Resource): + # This won't type check against the TypeVar, but we need it as the default. + default_config: type[ModuleConfigT] = ModuleConfig # type: ignore[assignment] + _rpc: RPCSpec | None = None - _tf: TFSpec | None = None + _tf: TFSpec[Any] | None = None _loop: asyncio.AbstractEventLoop | None = None _loop_thread: threading.Thread | None _disposables: CompositeDisposable @@ -93,10 +105,8 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): rpc_calls: list[str] = [] - default_config: type[ModuleConfigT] = ModuleConfig # type: ignore[assignment] - - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - super().__init__(*args, **kwargs) + def __init__(self, config_args: dict[str, Any]): + super().__init__(**config_args) self._module_closed_lock = threading.Lock() self._loop, self._loop_thread = get_loop() self._disposables = CompositeDisposable() @@ -338,7 +348,7 @@ def __get__( module_info = _module_info_descriptor() @classproperty - def blueprint(self): # type: ignore[no-untyped-def] + def blueprint(self) -> _BlueprintPartial: # Here to prevent circular imports. from dimos.core.blueprints import Blueprint @@ -409,7 +419,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not hasattr(cls, name) or getattr(cls, name) is None: setattr(cls, name, None) - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any): self.ref = None # type: ignore[assignment] try: @@ -427,7 +437,7 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] inner, *_ = get_args(ann) or (Any,) stream = In(inner, name, self) # type: ignore[assignment] setattr(self, name, stream) - super().__init__(*args, **kwargs) + super().__init__(config_args=kwargs) def __str__(self) -> str: return f"{self.__class__.__name__}" @@ -465,7 +475,7 @@ def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]): # type: input_stream.connection = remote_stream -ModuleT = TypeVar("ModuleT", bound="Module[Any]") +ModuleSpec = tuple[type[ModuleBase], GlobalConfig, dict[str, Any]] def is_module_type(value: Any) -> bool: diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py index 3a7961fcea..10227eae93 100644 --- a/dimos/core/module_coordinator.py +++ b/dimos/core/module_coordinator.py @@ -19,12 +19,12 @@ from typing import TYPE_CHECKING, Any from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.module import ModuleBase, ModuleSpec from dimos.core.resource import Resource from dimos.core.worker_manager import WorkerManager from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.core.module import Module, ModuleT from dimos.core.resource_monitor.monitor import StatsMonitor from dimos.core.rpc_client import ModuleProxy from dimos.core.worker import Worker @@ -37,7 +37,7 @@ class ModuleCoordinator(Resource): # type: ignore[misc] _global_config: GlobalConfig _n: int | None = None _memory_limit: str = "auto" - _deployed_modules: dict[type[Module], ModuleProxy] + _deployed_modules: dict[type[ModuleBase], ModuleProxy] _stats_monitor: StatsMonitor | None = None def __init__( @@ -115,17 +115,20 @@ def stop(self) -> None: self._client.close_all() # type: ignore[union-attr] - def deploy(self, module_class: type[ModuleT], *args, **kwargs) -> ModuleProxy: # type: ignore[no-untyped-def] + def deploy( + self, + module_class: type[ModuleBase[Any]], + global_config: GlobalConfig = global_config, + **kwargs: Any, + ) -> ModuleProxy: if not self._client: raise ValueError("Trying to dimos.deploy before the client has started") - module: ModuleProxy = self._client.deploy(module_class, *args, **kwargs) # type: ignore[union-attr, attr-defined, assignment] - self._deployed_modules[module_class] = module - return module + module = self._client.deploy(module_class, global_config, kwargs) + self._deployed_modules[module_class] = module # type: ignore[assignment] + return module # type: ignore[return-value] - def deploy_parallel( - self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[str, Any]]] - ) -> list[ModuleProxy]: + def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]: if not self._client: raise ValueError("Not started") @@ -148,7 +151,7 @@ def start_all_modules(self) -> None: if hasattr(module, "on_system_modules"): module.on_system_modules(module_list) - def get_instance(self, module: type[ModuleT]) -> ModuleProxy: + def get_instance(self, module: type[ModuleBase]) -> ModuleProxy: return self._deployed_modules.get(module) # type: ignore[return-value, no-any-return] def loop(self) -> None: diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 6a93e6453a..f4a674cb5d 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -40,7 +40,6 @@ class MyCppModule(NativeModule): from __future__ import annotations -from dataclasses import dataclass, field, fields import enum import inspect import json @@ -48,13 +47,21 @@ class MyCppModule(NativeModule): from pathlib import Path import signal import subprocess +import sys import threading from typing import IO, Any +from pydantic import Field + from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.utils.logging_config import setup_logger +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + logger = setup_logger() @@ -63,15 +70,14 @@ class LogFormat(enum.Enum): JSON = "json" -@dataclass(kw_only=True) class NativeModuleConfig(ModuleConfig): """Configuration for a native (C/C++) subprocess module.""" executable: str build_command: str | None = None cwd: str | None = None - extra_args: list[str] = field(default_factory=list) - extra_env: dict[str, str] = field(default_factory=dict) + extra_args: list[str] = Field(default_factory=list) + extra_env: dict[str, str] = Field(default_factory=dict) shutdown_timeout: float = 10.0 log_format: LogFormat = LogFormat.TEXT @@ -85,26 +91,29 @@ def to_cli_args(self) -> list[str]: or its parents) and converts them to ``["--name", str(value)]`` pairs. Skips fields whose values are ``None`` and fields in ``cli_exclude``. """ - ignore_fields = {f.name for f in fields(NativeModuleConfig)} + ignore_fields = {f for f in NativeModuleConfig.model_fields} args: list[str] = [] - for f in fields(self): - if f.name in ignore_fields: + for f in self.__class__.model_fields: + if f in ignore_fields: continue - if f.name in self.cli_exclude: + if f in self.cli_exclude: continue - val = getattr(self, f.name) + val = getattr(self, f) if val is None: continue if isinstance(val, bool): - args.extend([f"--{f.name}", str(val).lower()]) + args.extend([f"--{f}", str(val).lower()]) elif isinstance(val, list): - args.extend([f"--{f.name}", ",".join(str(v) for v in val)]) + args.extend([f"--{f}", ",".join(str(v) for v in val)]) else: - args.extend([f"--{f.name}", str(val)]) + args.extend([f"--{f}", str(val)]) return args -class NativeModule(Module[NativeModuleConfig]): +_NativeConfig = TypeVar("_NativeConfig", bound=NativeModuleConfig, default=NativeModuleConfig) + + +class NativeModule(Module[_NativeConfig]): """Module that wraps a native executable as a managed subprocess. Subclass this, declare In/Out ports, and set ``default_config`` to a @@ -118,13 +127,13 @@ class NativeModule(Module[NativeModuleConfig]): LCM topics directly. On ``stop()``, the process receives SIGTERM. """ - default_config: type[NativeModuleConfig] = NativeModuleConfig + default_config: type[_NativeConfig] = NativeModuleConfig # type: ignore[assignment] _process: subprocess.Popen[bytes] | None = None _watchdog: threading.Thread | None = None _stopping: bool = False - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._resolve_paths() @rpc diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index 30677bd1f7..19dbf62c74 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -113,14 +113,13 @@ class ModuleC(Module): def test_get_connection_set() -> None: - assert _BlueprintAtom.create(CatModule, args=("arg1",), kwargs={"k": "v"}) == _BlueprintAtom( + assert _BlueprintAtom.create(CatModule, kwargs={"k": "v"}) == _BlueprintAtom( module=CatModule, streams=( StreamRef(name="pet_cat", type=Petting, direction="in"), StreamRef(name="scratches", type=Scratch, direction="out"), ), module_refs=(), - args=("arg1",), kwargs={"k": "v"}, ) @@ -137,7 +136,6 @@ def test_autoconnect() -> None: StreamRef(name="data2", type=Data2, direction="out"), ), module_refs=(), - args=(), kwargs={}, ), _BlueprintAtom( @@ -148,7 +146,6 @@ def test_autoconnect() -> None: StreamRef(name="data3", type=Data3, direction="out"), ), module_refs=(), - args=(), kwargs={}, ), ) @@ -342,11 +339,11 @@ def test_future_annotations_support() -> None: """ # Test that streams are properly extracted from modules with future annotations - out_blueprint = _BlueprintAtom.create(FutureModuleOut, args=(), kwargs={}) + out_blueprint = _BlueprintAtom.create(FutureModuleOut, kwargs={}) assert len(out_blueprint.streams) == 1 assert out_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="out") - in_blueprint = _BlueprintAtom.create(FutureModuleIn, args=(), kwargs={}) + in_blueprint = _BlueprintAtom.create(FutureModuleIn, kwargs={}) assert len(in_blueprint.streams) == 1 assert in_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="in") diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 197539ef67..3bd1383761 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -39,9 +39,6 @@ class Navigation(Module): @rpc def navigate_to(self, target: Vector3) -> bool: ... - def __init__(self) -> None: - super().__init__() - @rpc def start(self) -> None: def _odom(msg) -> None: diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index d17775130e..e77b8f9a53 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -18,7 +18,6 @@ The echo script writes received CLI args to a temp file for assertions. """ -from dataclasses import dataclass import json from pathlib import Path import time @@ -59,7 +58,6 @@ def read_json_file(path: str) -> dict[str, str]: return result -@dataclass(kw_only=True) class StubNativeConfig(NativeModuleConfig): executable: str = _ECHO log_format: LogFormat = LogFormat.TEXT diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index a7c949b33a..16cb44b907 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -15,6 +15,7 @@ from collections.abc import Callable import threading import time +from typing import Any import pytest @@ -28,15 +29,15 @@ class SubscriberBase(Module): - sub1_msgs: list[Odometry] = None - sub2_msgs: list[Odometry] = None + sub1_msgs: list[Odometry] + sub2_msgs: list[Odometry] - def __init__(self) -> None: + def __init__(self, **kwargs: Any) -> None: self.sub1_msgs = [] self.sub2_msgs = [] self._sub1_received = threading.Event() self._sub2_received = threading.Event() - super().__init__() + super().__init__(**kwargs) def _sub1_callback(self, msg) -> None: self.sub1_msgs.append(msg) diff --git a/dimos/core/test_worker.py b/dimos/core/test_worker.py index a5217f2dd6..306b3fdb3d 100644 --- a/dimos/core/test_worker.py +++ b/dimos/core/test_worker.py @@ -17,6 +17,7 @@ import pytest from dimos.core.core import rpc +from dimos.core.global_config import global_config from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.core.worker_manager import WorkerManager @@ -99,7 +100,7 @@ def _create(n_workers): @pytest.mark.slow def test_worker_manager_basic(create_worker_manager): worker_manager = create_worker_manager(n_workers=2) - module = worker_manager.deploy(SimpleModule) + module = worker_manager.deploy(SimpleModule, global_config, {}) module.start() result = module.increment() @@ -117,8 +118,8 @@ def test_worker_manager_basic(create_worker_manager): @pytest.mark.slow def test_worker_manager_multiple_different_modules(create_worker_manager): worker_manager = create_worker_manager(n_workers=2) - module1 = worker_manager.deploy(SimpleModule) - module2 = worker_manager.deploy(AnotherModule) + module1 = worker_manager.deploy(SimpleModule, global_config, {}) + module2 = worker_manager.deploy(AnotherModule, global_config, {}) module1.start() module2.start() @@ -141,9 +142,9 @@ def test_worker_manager_parallel_deployment(create_worker_manager): worker_manager = create_worker_manager(n_workers=2) modules = worker_manager.deploy_parallel( [ - (SimpleModule, (), {}), - (AnotherModule, (), {}), - (ThirdModule, (), {}), + (SimpleModule, global_config, {}), + (AnotherModule, global_config, {}), + (ThirdModule, global_config, {}), ] ) @@ -175,8 +176,8 @@ def test_collect_stats(create_worker_manager): from dimos.core.resource_monitor.monitor import StatsMonitor manager = create_worker_manager(n_workers=2) - module1 = manager.deploy(SimpleModule) - module2 = manager.deploy(AnotherModule) + module1 = manager.deploy(SimpleModule, global_config, {}) + module2 = manager.deploy(AnotherModule, global_config, {}) module1.start() module2.start() @@ -219,8 +220,8 @@ def log_stats(self, coordinator, workers): @pytest.mark.slow def test_worker_pool_modules_share_workers(create_worker_manager): manager = create_worker_manager(n_workers=1) - module1 = manager.deploy(SimpleModule) - module2 = manager.deploy(AnotherModule) + module1 = manager.deploy(SimpleModule, global_config, {}) + module2 = manager.deploy(AnotherModule, global_config, {}) module1.start() module2.start() diff --git a/dimos/core/testing.py b/dimos/core/testing.py index 6431c09dbd..3bb5865192 100644 --- a/dimos/core/testing.py +++ b/dimos/core/testing.py @@ -14,6 +14,7 @@ from threading import Event, Thread import time +from typing import Any from dimos.core.core import rpc from dimos.core.module import Module @@ -32,8 +33,8 @@ class MockRobotClient(Module): mov_msg_count = 0 - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._stop_event = Event() self._thread = None diff --git a/dimos/core/worker.py b/dimos/core/worker.py index 3a98e6b7ba..dca561f16c 100644 --- a/dimos/core/worker.py +++ b/dimos/core/worker.py @@ -15,19 +15,19 @@ import logging import multiprocessing +from multiprocessing.connection import Connection import os import sys import threading import traceback from typing import TYPE_CHECKING, Any +from dimos.core.global_config import GlobalConfig, global_config from dimos.utils.logging_config import setup_logger from dimos.utils.sequential_ids import SequentialIds if TYPE_CHECKING: - from multiprocessing.connection import Connection - - from dimos.core.module import ModuleT + from dimos.core.module import ModuleBase logger = setup_logger() @@ -75,7 +75,7 @@ class Actor: def __init__( self, conn: Connection | None, - module_class: type[ModuleT], + module_class: type[ModuleBase], worker_id: int, module_id: int = 0, lock: threading.Lock | None = None, @@ -143,8 +143,6 @@ def reset_forkserver_context() -> None: class Worker: - """Generic worker process that can host multiple modules.""" - def __init__(self) -> None: self._lock = threading.Lock() self._modules: dict[int, Actor] = {} @@ -198,14 +196,15 @@ def start_process(self) -> None: def deploy_module( self, - module_class: type[ModuleT], - args: tuple[Any, ...] = (), - kwargs: dict[Any, Any] | None = None, + module_class: type[ModuleBase], + global_config: GlobalConfig = global_config, + kwargs: dict[str, Any] | None = None, ) -> Actor: if self._conn is None: raise RuntimeError("Worker process not started") kwargs = kwargs or {} + kwargs["g"] = global_config module_id = _module_ids.next() # Send deploy_module request to the worker process @@ -213,7 +212,6 @@ def deploy_module( "type": "deploy_module", "module_id": module_id, "module_class": module_class, - "args": args, "kwargs": kwargs, } with self._lock: @@ -293,10 +291,7 @@ def _suppress_console_output() -> None: ] -def _worker_entrypoint( - conn: Connection, - worker_id: int, -) -> None: +def _worker_entrypoint(conn: Connection, worker_id: int) -> None: instances: dict[int, Any] = {} try: @@ -346,10 +341,9 @@ def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) -> if req_type == "deploy_module": module_class = request["module_class"] - args = request.get("args", ()) - kwargs = request.get("kwargs", {}) + kwargs = request["kwargs"] module_id = request["module_id"] - instance = module_class(*args, **kwargs) + instance = module_class(**kwargs) instances[module_id] = instance response["result"] = module_id diff --git a/dimos/core/worker_manager.py b/dimos/core/worker_manager.py index 2b41f634e8..4cd5eec8d7 100644 --- a/dimos/core/worker_manager.py +++ b/dimos/core/worker_manager.py @@ -14,16 +14,16 @@ from __future__ import annotations +from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Any +from typing import Any +from dimos.core.global_config import GlobalConfig +from dimos.core.module import ModuleBase, ModuleSpec from dimos.core.rpc_client import RPCClient from dimos.core.worker import Worker from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from dimos.core.module import ModuleT - logger = setup_logger() @@ -47,7 +47,9 @@ def start(self) -> None: def _select_worker(self) -> Worker: return min(self._workers, key=lambda w: w.module_count) - def deploy(self, module_class: type[ModuleT], *args: Any, **kwargs: Any) -> RPCClient: + def deploy( + self, module_class: type[ModuleBase], global_config: GlobalConfig, kwargs: dict[str, Any] + ) -> RPCClient: if self._closed: raise RuntimeError("WorkerManager is closed") @@ -56,12 +58,10 @@ def deploy(self, module_class: type[ModuleT], *args: Any, **kwargs: Any) -> RPCC self.start() worker = self._select_worker() - actor = worker.deploy_module(module_class, args=args, kwargs=kwargs) + actor = worker.deploy_module(module_class, global_config, kwargs=kwargs) return RPCClient(actor, module_class) - def deploy_parallel( - self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[Any, Any]]] - ) -> list[RPCClient]: + def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient]: if self._closed: raise RuntimeError("WorkerManager is closed") @@ -72,17 +72,17 @@ def deploy_parallel( # Pre-assign workers sequentially (so least-loaded accounting is # correct), then deploy concurrently via threads. The per-worker lock # serializes deploys that land on the same worker process. - assignments: list[tuple[Worker, type[ModuleT], tuple[Any, ...], dict[Any, Any]]] = [] - for module_class, args, kwargs in module_specs: + assignments: list[tuple[Worker, type[ModuleBase], GlobalConfig, dict[str, Any]]] = [] + for module_class, global_config, kwargs in module_specs: worker = self._select_worker() worker.reserve_slot() - assignments.append((worker, module_class, args, kwargs)) + assignments.append((worker, module_class, global_config, kwargs)) def _deploy( - item: tuple[Worker, type[ModuleT], tuple[Any, ...], dict[Any, Any]], + item: tuple[Worker, type[ModuleBase], GlobalConfig, dict[str, Any]], ) -> RPCClient: - worker, module_class, args, kwargs = item - actor = worker.deploy_module(module_class, args=args, kwargs=kwargs) + worker, module_class, global_config, kwargs = item + actor = worker.deploy_module(module_class, global_config=global_config, kwargs=kwargs) return RPCClient(actor, module_class) with ThreadPoolExecutor(max_workers=len(assignments)) as pool: diff --git a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py index 9161185d50..ec19d6844e 100644 --- a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py +++ b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py @@ -14,11 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import logging import sys import threading import time +from typing import Any import numpy as np @@ -43,28 +43,22 @@ Gst.init(None) -@dataclass class Config(ModuleConfig): frame_id: str = "camera" + host: str = "localhost" + port: int = 5000 + timestamp_offset: float = 0.0 + reconnect_interval: float = 5.0 -class GstreamerCameraModule(Module): +class GstreamerCameraModule(Module[Config]): """Module that captures frames from a remote camera using GStreamer TCP with absolute timestamps.""" default_config = Config - config: Config video: Out[Image] - def __init__( # type: ignore[no-untyped-def] - self, - host: str = "localhost", - port: int = 5000, - timestamp_offset: float = 0.0, - reconnect_interval: float = 5.0, - *args, - **kwargs, - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize the GStreamer TCP camera module. Args: @@ -74,10 +68,10 @@ def __init__( # type: ignore[no-untyped-def] timestamp_offset: Offset to add to timestamps (useful for clock synchronization) reconnect_interval: Seconds to wait before attempting reconnection """ - self.host = host - self.port = port - self.timestamp_offset = timestamp_offset - self.reconnect_interval = reconnect_interval + super().__init__(**kwargs) + self.host = self.config.host + self.port = self.config.port + self.reconnect_interval = self.config.reconnect_interval self.pipeline = None self.appsink = None @@ -88,7 +82,6 @@ def __init__( # type: ignore[no-untyped-def] self.frame_count = 0 self.last_log_time = time.time() self.reconnect_timer_id = None - super().__init__(**kwargs) @rpc def start(self) -> None: @@ -257,7 +250,7 @@ def _on_new_sample(self, appsink): # type: ignore[no-untyped-def] if buffer.pts != Gst.CLOCK_TIME_NONE: # Convert nanoseconds to seconds and add offset # This is the absolute time from when the frame was captured - timestamp = (buffer.pts / 1e9) + self.timestamp_offset + timestamp = (buffer.pts / 1e9) + self.config.timestamp_offset # Skip frames with invalid timestamps (before year 2000) # This filters out initial gray frames with relative timestamps diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index 11821d4724..0f055f0352 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -13,16 +13,15 @@ # limitations under the License. from collections.abc import Callable -from dataclasses import dataclass, field import time from typing import Any +from pydantic import Field import reactivex as rx from dimos.agents.annotation import skill from dimos.core.blueprints import autoconnect from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware @@ -43,10 +42,9 @@ def default_transform() -> Transform: ) -@dataclass class CameraModuleConfig(ModuleConfig): frame_id: str = "camera_link" - transform: Transform | None = field(default_factory=default_transform) + transform: Transform | None = Field(default_factory=default_transform) hardware: Callable[[], CameraHardware[Any]] | CameraHardware[Any] = Webcam frequency: float = 0.0 # Hz, 0 means no limit @@ -55,16 +53,9 @@ class CameraModule(Module[CameraModuleConfig], perception.Camera): color_image: Out[Image] camera_info: Out[CameraInfo] - hardware: CameraHardware[Any] - - config: CameraModuleConfig default_config = CameraModuleConfig - _global_config: GlobalConfig - - def __init__(self, *args: Any, cfg: GlobalConfig = global_config, **kwargs: Any) -> None: - self._global_config = cfg - self._latest_image: Image | None = None - super().__init__(*args, **kwargs) + hardware: CameraHardware[Any] + _latest_image: Image | None = None @rpc def start(self) -> None: diff --git a/dimos/hardware/sensors/camera/realsense/camera.py b/dimos/hardware/sensors/camera/realsense/camera.py index f34b9a2881..5908525826 100644 --- a/dimos/hardware/sensors/camera/realsense/camera.py +++ b/dimos/hardware/sensors/camera/realsense/camera.py @@ -15,13 +15,13 @@ from __future__ import annotations import atexit -from dataclasses import dataclass, field import threading import time from typing import TYPE_CHECKING import cv2 import numpy as np +from pydantic import Field import reactivex as rx from scipy.spatial.transform import Rotation # type: ignore[import-untyped] @@ -55,14 +55,13 @@ def default_base_transform() -> Transform: ) -@dataclass class RealSenseCameraConfig(ModuleConfig, DepthCameraConfig): width: int = 848 height: int = 480 fps: int = 15 camera_name: str = "camera" base_frame_id: str = "base_link" - base_transform: Transform | None = field(default_factory=default_base_transform) + base_transform: Transform | None = Field(default_factory=default_base_transform) align_depth_to_color: bool = True enable_depth: bool = True enable_pointcloud: bool = False @@ -71,14 +70,13 @@ class RealSenseCameraConfig(ModuleConfig, DepthCameraConfig): serial_number: str | None = None -class RealSenseCamera(DepthCameraHardware, Module, perception.DepthCamera): +class RealSenseCamera(DepthCameraHardware, Module[RealSenseCameraConfig], perception.DepthCamera): color_image: Out[Image] depth_image: Out[Image] pointcloud: Out[PointCloud2] camera_info: Out[CameraInfo] depth_camera_info: Out[CameraInfo] - config: RealSenseCameraConfig default_config = RealSenseCameraConfig @property diff --git a/dimos/hardware/sensors/camera/spec.py b/dimos/hardware/sensors/camera/spec.py index 23fd1a076e..be37ec734a 100644 --- a/dimos/hardware/sensors/camera/spec.py +++ b/dimos/hardware/sensors/camera/spec.py @@ -13,19 +13,19 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Generic, Protocol, TypeVar +from typing import TypeVar from reactivex.observable import Observable from dimos.msgs.geometry_msgs import Quaternion, Transform from dimos.msgs.sensor_msgs import CameraInfo from dimos.msgs.sensor_msgs.Image import Image -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.protocol.service.spec import BaseConfig, Configurable OPTICAL_ROTATION = Quaternion(-0.5, 0.5, -0.5, 0.5) -class CameraConfig(Protocol): +class CameraConfig(BaseConfig): frame_id_prefix: str | None width: int height: int @@ -35,7 +35,7 @@ class CameraConfig(Protocol): CameraConfigT = TypeVar("CameraConfigT", bound=CameraConfig) -class CameraHardware(ABC, Configurable[CameraConfigT], Generic[CameraConfigT]): +class CameraHardware(ABC, Configurable[CameraConfigT]): @abstractmethod def image_stream(self) -> Observable[Image]: pass @@ -62,8 +62,6 @@ class DepthCameraConfig(CameraConfig): class DepthCameraHardware(ABC): """Abstract class for depth camera modules (RealSense, ZED, etc.).""" - config: DepthCameraConfig - @abstractmethod def get_color_camera_info(self) -> CameraInfo | None: """Get color camera intrinsics.""" diff --git a/dimos/hardware/sensors/camera/zed/__init__.py b/dimos/hardware/sensors/camera/zed/__init__.py index f8e73273bf..6e3b905e90 100644 --- a/dimos/hardware/sensors/camera/zed/__init__.py +++ b/dimos/hardware/sensors/camera/zed/__init__.py @@ -18,15 +18,15 @@ from dimos.msgs.sensor_msgs.CameraInfo import CalibrationProvider -# Check if ZED SDK is available try: - import pyzed.sl as sl # noqa: F401 + import pyzed.sl # noqa: F401 + # This awkwardness is needed as pytest implicitly imports this to collect + # the test in this directory. HAS_ZED_SDK = True except ImportError: HAS_ZED_SDK = False -# Only import ZED classes if SDK is available if HAS_ZED_SDK: from dimos.hardware.sensors.camera.zed.camera import ZEDCamera, ZEDModule, zed_camera else: @@ -43,7 +43,7 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." ) - def zed_camera(*args: object, **kwargs: object) -> None: # type: ignore[no-redef] + def zed_camera(*args: object, **kwargs: object) -> None: # type: ignore[misc,no-redef] raise ModuleNotFoundError( "ZED SDK not installed. Please install pyzed package to use ZED camera functionality.", name="pyzed", diff --git a/dimos/hardware/sensors/camera/zed/camera.py b/dimos/hardware/sensors/camera/zed/camera.py index 6ce2fc86b2..2df9afd70c 100644 --- a/dimos/hardware/sensors/camera/zed/camera.py +++ b/dimos/hardware/sensors/camera/zed/camera.py @@ -15,11 +15,11 @@ from __future__ import annotations import atexit -from dataclasses import dataclass, field import threading import time import cv2 +from pydantic import Field import pyzed.sl as sl import reactivex as rx @@ -50,14 +50,13 @@ def default_base_transform() -> Transform: ) -@dataclass class ZEDCameraConfig(ModuleConfig, DepthCameraConfig): width: int = 1280 height: int = 720 fps: int = 15 camera_name: str = "camera" base_frame_id: str = "base_link" - base_transform: Transform | None = field(default_factory=default_base_transform) + base_transform: Transform | None = Field(default_factory=default_base_transform) align_depth_to_color: bool = True enable_depth: bool = True enable_pointcloud: bool = False @@ -76,14 +75,13 @@ class ZEDCameraConfig(ModuleConfig, DepthCameraConfig): world_frame: str = "world" -class ZEDCamera(DepthCameraHardware, Module, perception.DepthCamera): +class ZEDCamera(DepthCameraHardware, Module[ZEDCameraConfig], perception.DepthCamera): color_image: Out[Image] depth_image: Out[Image] pointcloud: Out[PointCloud2] camera_info: Out[CameraInfo] depth_camera_info: Out[CameraInfo] - config: ZEDCameraConfig default_config = ZEDCameraConfig @property diff --git a/dimos/hardware/sensors/camera/zed/test_zed.py b/dimos/hardware/sensors/camera/zed/test_zed.py index 2d912553c6..2716e809a5 100644 --- a/dimos/hardware/sensors/camera/zed/test_zed.py +++ b/dimos/hardware/sensors/camera/zed/test_zed.py @@ -13,14 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + +from dimos.hardware.sensors.camera import zed from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +@pytest.mark.skipif(not zed.HAS_ZED_SDK, reason="ZED SDK not installed") def test_zed_import_and_calibration_access() -> None: """Test that zed module can be imported and calibrations accessed.""" - # Import zed module from camera - from dimos.hardware.sensors.camera import zed - # Test that CameraInfo is accessible assert hasattr(zed, "CameraInfo") diff --git a/dimos/hardware/sensors/fake_zed_module.py b/dimos/hardware/sensors/fake_zed_module.py index ec5613077d..ca5014337b 100644 --- a/dimos/hardware/sensors/fake_zed_module.py +++ b/dimos/hardware/sensors/fake_zed_module.py @@ -17,9 +17,9 @@ FakeZEDModule - Replays recorded ZED data for testing without hardware. """ -from dataclasses import dataclass import functools import logging +from typing import Any from dimos_lcm.sensor_msgs import CameraInfo import numpy as np @@ -37,8 +37,8 @@ logger = setup_logger(level=logging.INFO) -@dataclass class FakeZEDModuleConfig(ModuleConfig): + recording_path: str frame_id: str = "zed_camera" @@ -54,9 +54,8 @@ class FakeZEDModule(Module[FakeZEDModuleConfig]): pose: Out[PoseStamped] default_config = FakeZEDModuleConfig - config: FakeZEDModuleConfig - def __init__(self, recording_path: str, **kwargs: object) -> None: + def __init__(self, **kwargs: Any) -> None: """ Initialize FakeZEDModule with recording path. @@ -65,7 +64,7 @@ def __init__(self, recording_path: str, **kwargs: object) -> None: """ super().__init__(**kwargs) - self.recording_path = recording_path + self.recording_path = self.config.recording_path self._running = False # Initialize TF publisher diff --git a/dimos/hardware/sensors/lidar/fastlio2/module.py b/dimos/hardware/sensors/lidar/fastlio2/module.py index fb894ddce5..c1a96a525b 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/module.py +++ b/dimos/hardware/sensors/lidar/fastlio2/module.py @@ -30,12 +30,13 @@ from __future__ import annotations -from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Annotated + +from pydantic.experimental.pipeline import validate_as from dimos.core.native_module import NativeModule, NativeModuleConfig -from dimos.core.stream import Out # noqa: TC001 +from dimos.core.stream import Out from dimos.hardware.sensors.lidar.livox.ports import ( SDK_CMD_DATA_PORT, SDK_HOST_CMD_DATA_PORT, @@ -48,14 +49,13 @@ SDK_POINT_DATA_PORT, SDK_PUSH_MSG_PORT, ) -from dimos.msgs.nav_msgs.Odometry import Odometry # noqa: TC001 -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 # noqa: TC001 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.spec import mapping, perception _CONFIG_DIR = Path(__file__).parent / "config" -@dataclass(kw_only=True) class FastLio2Config(NativeModuleConfig): """Config for the FAST-LIO2 + Livox Mid-360 native module.""" @@ -92,7 +92,9 @@ class FastLio2Config(NativeModuleConfig): # FAST-LIO YAML config (relative to config/ dir, or absolute path) # C++ binary reads YAML directly via yaml-cpp - config: str = "mid360.yaml" + config: Annotated[ + Path, validate_as(...).transform(lambda p: p if p.is_absolute() else _CONFIG_DIR / p) + ] = Path("mid360.yaml") # SDK port configuration (see livox/ports.py for defaults) cmd_data_port: int = SDK_CMD_DATA_PORT @@ -112,15 +114,10 @@ class FastLio2Config(NativeModuleConfig): # config is not a CLI arg (config_path is) cli_exclude: frozenset[str] = frozenset({"config"}) - def __post_init__(self) -> None: - if self.config_path is None: - path = Path(self.config) - if not path.is_absolute(): - path = _CONFIG_DIR / path - self.config_path = str(path.resolve()) - -class FastLio2(NativeModule, perception.Lidar, perception.Odometry, mapping.GlobalPointcloud): +class FastLio2( + NativeModule[FastLio2Config], perception.Lidar, perception.Odometry, mapping.GlobalPointcloud +): """FAST-LIO2 SLAM module with integrated Livox Mid-360 driver. Ports: @@ -129,7 +126,7 @@ class FastLio2(NativeModule, perception.Lidar, perception.Odometry, mapping.Glob global_map (Out[PointCloud2]): Global voxel map (optional, enable via map_freq > 0). """ - default_config: type[FastLio2Config] = FastLio2Config # type: ignore[assignment] + default_config = FastLio2Config lidar: Out[PointCloud2] odometry: Out[Odometry] global_map: Out[PointCloud2] diff --git a/dimos/hardware/sensors/lidar/livox/module.py b/dimos/hardware/sensors/lidar/livox/module.py index 2e470b21ef..999cdd9aa1 100644 --- a/dimos/hardware/sensors/lidar/livox/module.py +++ b/dimos/hardware/sensors/lidar/livox/module.py @@ -26,11 +26,10 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING from dimos.core.native_module import NativeModule, NativeModuleConfig -from dimos.core.stream import Out # noqa: TC001 +from dimos.core.stream import Out from dimos.hardware.sensors.lidar.livox.ports import ( SDK_CMD_DATA_PORT, SDK_HOST_CMD_DATA_PORT, @@ -43,12 +42,11 @@ SDK_POINT_DATA_PORT, SDK_PUSH_MSG_PORT, ) -from dimos.msgs.sensor_msgs.Imu import Imu # noqa: TC001 -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 # noqa: TC001 +from dimos.msgs.sensor_msgs.Imu import Imu +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.spec import perception -@dataclass(kw_only=True) class Mid360Config(NativeModuleConfig): """Config for the C++ Mid-360 native module.""" @@ -76,7 +74,7 @@ class Mid360Config(NativeModuleConfig): host_log_data_port: int = SDK_HOST_LOG_DATA_PORT -class Mid360(NativeModule, perception.Lidar, perception.IMU): +class Mid360(NativeModule[Mid360Config], perception.Lidar, perception.IMU): """Livox Mid-360 LiDAR module backed by a native C++ binary. Ports: diff --git a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py index 2c11b0cc10..7dd7e8c119 100644 --- a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py +++ b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py @@ -26,7 +26,6 @@ - Supports velocity-based and position-based control modes """ -from dataclasses import dataclass import math import threading import time @@ -43,10 +42,11 @@ logger = setup_logger() -@dataclass class CartesianMotionControllerConfig(ModuleConfig): """Configuration for Cartesian motion controller.""" + arm_driver: Any = None + # Control loop parameters control_frequency: float = 20.0 # Hz - Cartesian control loop rate command_timeout: float = 30.0 # seconds - timeout for stale targets (RPC mode needs longer) @@ -78,7 +78,7 @@ class CartesianMotionControllerConfig(ModuleConfig): control_frame: str = "world" # Frame for target poses (world, base_link, etc.) -class CartesianMotionController(Module): +class CartesianMotionController(Module[CartesianMotionControllerConfig]): """ Hardware-agnostic Cartesian motion controller. @@ -94,7 +94,6 @@ class CartesianMotionController(Module): """ default_config = CartesianMotionControllerConfig - config: CartesianMotionControllerConfig # Type hint for proper attribute access # RPC methods to request from other modules (resolved at blueprint build time) rpc_calls = [ @@ -112,7 +111,7 @@ class CartesianMotionController(Module): cartesian_velocity: Out[Twist] = None # type: ignore[assignment] current_pose: Out[PoseStamped] = None # type: ignore[assignment] - def __init__(self, arm_driver: Any = None, *args: Any, **kwargs: Any) -> None: + def __init__(self, **kwargs: Any) -> None: """ Initialize the Cartesian motion controller. @@ -120,10 +119,10 @@ def __init__(self, arm_driver: Any = None, *args: Any, **kwargs: Any) -> None: arm_driver: (Optional) Hardware driver reference (legacy mode). When using blueprints, this is resolved automatically via rpc_calls. """ - super().__init__(*args, **kwargs) + super().__init__(**kwargs) # Hardware driver reference - set via arm_driver param (legacy) or RPC wiring (blueprint) - self._arm_driver_legacy = arm_driver + self._arm_driver_legacy = self.config.arm_driver # State tracking self._latest_joint_state: JointState | None = None diff --git a/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py b/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py index 1ce3149dd2..ebc6f3f53c 100644 --- a/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py +++ b/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py @@ -29,7 +29,6 @@ - reset(): Required to recover from FAULT state """ -from dataclasses import dataclass import threading import time from typing import Any @@ -44,14 +43,13 @@ logger = setup_logger() -@dataclass class JointTrajectoryControllerConfig(ModuleConfig): """Configuration for joint trajectory controller.""" control_frequency: float = 100.0 # Hz - trajectory execution rate -class JointTrajectoryController(Module): +class JointTrajectoryController(Module[JointTrajectoryControllerConfig]): """ Joint-space trajectory executor. @@ -72,7 +70,6 @@ class JointTrajectoryController(Module): """ default_config = JointTrajectoryControllerConfig - config: JointTrajectoryControllerConfig # Type hint for proper attribute access # Input topics joint_state: In[JointState] = None # type: ignore[assignment] # Feedback from arm driver @@ -82,8 +79,8 @@ class JointTrajectoryController(Module): # Output topics joint_position_command: Out[JointCommand] = None # type: ignore[assignment] # To arm driver - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # State machine self._state = TrajectoryState.IDLE diff --git a/dimos/manipulation/grasping/graspgen_module.py b/dimos/manipulation/grasping/graspgen_module.py index c988d3df51..7ec8cfeeaa 100644 --- a/dimos/manipulation/grasping/graspgen_module.py +++ b/dimos/manipulation/grasping/graspgen_module.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -from dataclasses import dataclass import os from pathlib import Path import sys @@ -42,7 +41,6 @@ COLLISION_FILTER_THRESHOLD = 0.02 -@dataclass class GraspGenConfig(DockerModuleConfig): """Configuration for GraspGen module.""" @@ -68,11 +66,9 @@ class GraspGenModule(Module[GraspGenConfig]): default_config = GraspGenConfig grasps: Out[PoseArray] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._sampler = self._gripper_info = None - self._initialized = False + _sampler = None + _gripper_info = None + _initialized = False @rpc def start(self) -> None: @@ -212,7 +208,7 @@ def _run_inference( return grasps_np, scores_np pc_mean = object_pc_filtered.mean(axis=0) - T_center = tra.translation_matrix(-pc_mean) + T_center = tra.translation_matrix(-pc_mean) # type: ignore[no-untyped-call] grasps_centered = np.array([T_center @ g for g in grasps_np]) scene_pc_centered = tra.transform_points(scene_pc, T_center) diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index 40dd6734c5..cab6c9f173 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -24,7 +24,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from collections.abc import Iterable from enum import Enum import threading import time @@ -82,18 +82,17 @@ class ManipulationState(Enum): FAULT = 4 -@dataclass class ManipulationModuleConfig(ModuleConfig): """Configuration for ManipulationModule.""" - robots: list[RobotModelConfig] = field(default_factory=list) + robots: Iterable[RobotModelConfig] = () planning_timeout: float = 10.0 enable_viz: bool = False planner_name: str = "rrt_connect" # "rrt_connect" kinematics_name: str = "jacobian" # "jacobian" or "drake_optimization" -class ManipulationModule(Module): +class ManipulationModule(Module[ManipulationModuleConfig]): """Base motion planning module with ControlCoordinator execution. - @rpc: Low-level building blocks (plan, execute, gripper) @@ -104,14 +103,11 @@ class ManipulationModule(Module): default_config = ManipulationModuleConfig - # Type annotation for the config attribute (mypy uses this) - config: ManipulationModuleConfig - # Input: Joint state from coordinator (for world sync) joint_state: In[JointState] - def __init__(self, *args: object, **kwargs: object) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # State machine self._state = ManipulationState.IDLE diff --git a/dimos/manipulation/pick_and_place_module.py b/dimos/manipulation/pick_and_place_module.py index 84ede61793..2016abeb4f 100644 --- a/dimos/manipulation/pick_and_place_module.py +++ b/dimos/manipulation/pick_and_place_module.py @@ -22,7 +22,6 @@ from __future__ import annotations -from dataclasses import dataclass, field import math from pathlib import Path import time @@ -32,7 +31,7 @@ from dimos.constants import DIMOS_PROJECT_ROOT from dimos.core.core import rpc from dimos.core.docker_runner import DockerModule as DockerRunner -from dimos.core.stream import In # noqa: TC001 +from dimos.core.stream import In from dimos.manipulation.grasping.graspgen_module import GraspGenModule from dimos.manipulation.manipulation_module import ( ManipulationModule, @@ -40,7 +39,7 @@ ) from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 from dimos.perception.detection.type.detection3d.object import ( - Object as DetObject, # noqa: TC001 + Object as DetObject, ) from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger @@ -56,7 +55,6 @@ _GRASPGEN_VIZ_CONTAINER_PATH = f"{_GRASPGEN_VIZ_CONTAINER_DIR}/visualization.json" -@dataclass class PickAndPlaceModuleConfig(ManipulationModuleConfig): """Configuration for PickAndPlaceModule (adds GraspGen settings).""" @@ -68,8 +66,8 @@ class PickAndPlaceModuleConfig(ManipulationModuleConfig): graspgen_grasp_threshold: float = -1.0 graspgen_filter_collisions: bool = False graspgen_save_visualization_data: bool = False - graspgen_visualization_output_path: Path = field( - default_factory=lambda: Path.home() / ".dimos" / "graspgen" / "visualization.json" + graspgen_visualization_output_path: Path = ( + Path.home() / ".dimos" / "graspgen" / "visualization.json" ) @@ -90,8 +88,8 @@ class PickAndPlaceModule(ManipulationModule): # Input: Objects from perception (for obstacle integration) objects: In[list[DetObject]] - def __init__(self, *args: object, **kwargs: object) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # GraspGen Docker runner (lazy initialized on first generate_grasps call) self._graspgen: DockerRunner | None = None diff --git a/dimos/manipulation/planning/spec/config.py b/dimos/manipulation/planning/spec/config.py index dc302689ea..e379fc1eb5 100644 --- a/dimos/manipulation/planning/spec/config.py +++ b/dimos/manipulation/planning/spec/config.py @@ -16,17 +16,16 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from collections.abc import Iterable, Sequence +from pathlib import Path -if TYPE_CHECKING: - from pathlib import Path +from pydantic import Field - from dimos.msgs.geometry_msgs import PoseStamped +from dimos.core.module import ModuleConfig +from dimos.msgs.geometry_msgs import PoseStamped -@dataclass -class RobotModelConfig: +class RobotModelConfig(ModuleConfig): """Configuration for adding a robot to the world. Attributes: @@ -60,24 +59,24 @@ class RobotModelConfig: joint_names: list[str] end_effector_link: str base_link: str = "base_link" - package_paths: dict[str, Path] = field(default_factory=dict) + package_paths: dict[str, Path] = Field(default_factory=dict) joint_limits_lower: list[float] | None = None joint_limits_upper: list[float] | None = None velocity_limits: list[float] | None = None auto_convert_meshes: bool = False - xacro_args: dict[str, str] = field(default_factory=dict) - collision_exclusion_pairs: list[tuple[str, str]] = field(default_factory=list) + xacro_args: dict[str, str] = Field(default_factory=dict) + collision_exclusion_pairs: Iterable[tuple[str, str]] = () # Motion constraints for trajectory generation max_velocity: float = 1.0 max_acceleration: float = 2.0 # Coordinator integration - joint_name_mapping: dict[str, str] = field(default_factory=dict) + joint_name_mapping: dict[str, str] = Field(default_factory=dict) coordinator_task_name: str | None = None gripper_hardware_id: str | None = None # TF publishing for extra links (e.g., camera mount) - tf_extra_links: list[str] = field(default_factory=list) + tf_extra_links: Sequence[str] = () # Home/observe joint configuration for go_home skill - home_joints: list[float] | None = None + home_joints: Iterable[float] | None = None # Pre-grasp offset distance in meters (along approach direction) pre_grasp_offset: float = 0.10 diff --git a/dimos/mapping/costmapper.py b/dimos/mapping/costmapper.py index fa0ce826f2..75b674b2a0 100644 --- a/dimos/mapping/costmapper.py +++ b/dimos/mapping/costmapper.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import asdict, dataclass, field +from dataclasses import asdict import time +from pydantic import Field from reactivex import operators as ops from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.mapping.pointclouds.occupancy import ( @@ -33,23 +33,17 @@ logger = setup_logger() -@dataclass class Config(ModuleConfig): algo: str = "height_cost" - config: OccupancyConfig = field(default_factory=HeightCostConfig) + config: OccupancyConfig = Field(default_factory=HeightCostConfig) -class CostMapper(Module): +class CostMapper(Module[Config]): default_config = Config - config: Config global_map: In[PointCloud2] global_costmap: Out[OccupancyGrid] - def __init__(self, cfg: GlobalConfig = global_config, **kwargs: object) -> None: - super().__init__(**kwargs) - self._global_config = cfg - @rpc def start(self) -> None: super().start() diff --git a/dimos/mapping/osm/current_location_map.py b/dimos/mapping/osm/current_location_map.py index ef0a832cd6..832116e25c 100644 --- a/dimos/mapping/osm/current_location_map.py +++ b/dimos/mapping/osm/current_location_map.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + from PIL import Image as PILImage, ImageDraw from dimos.mapping.osm.osm import MapImage, get_osm_map @@ -24,11 +26,11 @@ class CurrentLocationMap: - _vl_model: VlModel + _vl_model: VlModel[Any] _position: LatLon | None _map_image: MapImage | None - def __init__(self, vl_model: VlModel) -> None: + def __init__(self, vl_model: VlModel[Any]) -> None: self._vl_model = vl_model self._position = None self._map_image = None diff --git a/dimos/mapping/osm/query.py b/dimos/mapping/osm/query.py index 410f879c20..17fbfe3d4b 100644 --- a/dimos/mapping/osm/query.py +++ b/dimos/mapping/osm/query.py @@ -13,6 +13,7 @@ # limitations under the License. import re +from typing import Any from dimos.mapping.osm.osm import MapImage from dimos.mapping.types import LatLon @@ -25,7 +26,9 @@ logger = setup_logger() -def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) -> LatLon | None: +def query_for_one_position( + vl_model: VlModel[Any], map_image: MapImage, query: str +) -> LatLon | None: full_query = f"{_PROLOGUE} {query} {_JSON} If there's a match return the x, y coordinates from the image. Example: `[123, 321]`. If there's no match return `null`." response = vl_model.query(map_image.image, full_query) coords = tuple(map(int, re.findall(r"\d+", response))) @@ -35,7 +38,7 @@ def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) - def query_for_one_position_and_context( - vl_model: VlModel, map_image: MapImage, query: str, robot_position: LatLon + vl_model: VlModel[Any], map_image: MapImage, query: str, robot_position: LatLon ) -> tuple[LatLon, str] | None: example = '{"coordinates": [123, 321], "description": "A Starbucks on 27th Street"}' x, y = map_image.latlon_to_pixel(robot_position) diff --git a/dimos/mapping/voxels.py b/dimos/mapping/voxels.py index 124073cf49..c2078dc309 100644 --- a/dimos/mapping/voxels.py +++ b/dimos/mapping/voxels.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import time +from typing import Any import numpy as np import open3d as o3d # type: ignore[import-untyped] @@ -23,7 +23,6 @@ from reactivex.subject import Subject from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.sensor_msgs import PointCloud2 @@ -34,7 +33,6 @@ logger = setup_logger() -@dataclass class Config(ModuleConfig): frame_id: str = "world" # -1 never publishes, 0 publishes on every frame, >0 publishes at interval in seconds @@ -45,16 +43,14 @@ class Config(ModuleConfig): carve_columns: bool = True -class VoxelGridMapper(Module): +class VoxelGridMapper(Module[Config]): default_config = Config - config: Config lidar: In[PointCloud2] global_map: Out[PointCloud2] - def __init__(self, cfg: GlobalConfig = global_config, **kwargs: object) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - self._global_config = cfg dev = ( o3c.Device(self.config.device) diff --git a/dimos/memory/embedding.py b/dimos/memory/embedding.py index 4627ecfc35..e09e069f05 100644 --- a/dimos/memory/embedding.py +++ b/dimos/memory/embedding.py @@ -13,9 +13,10 @@ # limitations under the License. from collections.abc import Callable -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import cast +from pydantic import Field import reactivex as rx from reactivex import operators as ops from reactivex.observable import Observable @@ -32,9 +33,8 @@ from dimos.utils.reactive import getter_hot -@dataclass class Config(ModuleConfig): - embedding_model: EmbeddingModel = field(default_factory=CLIPModel) + embedding_model: EmbeddingModel = Field(default_factory=CLIPModel) @dataclass @@ -50,7 +50,6 @@ class SpatialEmbedding(SpatialEntry): class EmbeddingMemory(Module[Config]): default_config = Config - config: Config color_image: In[Image] global_costmap: In[OccupancyGrid] diff --git a/dimos/models/base.py b/dimos/models/base.py index 2269a6d0b8..d03ce5c539 100644 --- a/dimos/models/base.py +++ b/dimos/models/base.py @@ -16,21 +16,19 @@ from __future__ import annotations -from dataclasses import dataclass from functools import cached_property from typing import Annotated, Any import torch from dimos.core.resource import Resource -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.protocol.service import BaseConfig, Configurable # Device string type - 'cuda', 'cpu', 'cuda:0', 'cuda:1', etc. DeviceType = Annotated[str, "Device identifier (e.g., 'cuda', 'cpu', 'cuda:0')"] -@dataclass -class LocalModelConfig: +class LocalModelConfig(BaseConfig): device: DeviceType = "cuda" if torch.cuda.is_available() else "cpu" dtype: torch.dtype = torch.float32 warmup: bool = False @@ -127,7 +125,6 @@ def _ensure_cuda_initialized(self) -> None: pass -@dataclass class HuggingFaceModelConfig(LocalModelConfig): model_name: str = "" trust_remote_code: bool = True diff --git a/dimos/models/embedding/base.py b/dimos/models/embedding/base.py index c6b78fcf2c..520818aabf 100644 --- a/dimos/models/embedding/base.py +++ b/dimos/models/embedding/base.py @@ -15,7 +15,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass import time from typing import TYPE_CHECKING @@ -29,14 +28,12 @@ from dimos.msgs.sensor_msgs import Image -@dataclass class EmbeddingModelConfig(LocalModelConfig): """Base config for embedding models.""" normalize: bool = True -@dataclass class HuggingFaceEmbeddingModelConfig(HuggingFaceModelConfig): """Base config for HuggingFace-based embedding models.""" diff --git a/dimos/models/embedding/clip.py b/dimos/models/embedding/clip.py index 1b8d3e68bb..e3a61e9570 100644 --- a/dimos/models/embedding/clip.py +++ b/dimos/models/embedding/clip.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from functools import cached_property from PIL import Image as PILImage @@ -25,7 +24,6 @@ from dimos.msgs.sensor_msgs import Image -@dataclass class CLIPModelConfig(HuggingFaceEmbeddingModelConfig): model_name: str = "openai/clip-vit-base-patch32" dtype: torch.dtype = torch.float32 diff --git a/dimos/models/embedding/mobileclip.py b/dimos/models/embedding/mobileclip.py index c02361b367..8ad37936be 100644 --- a/dimos/models/embedding/mobileclip.py +++ b/dimos/models/embedding/mobileclip.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from functools import cached_property from typing import Any @@ -27,7 +26,6 @@ from dimos.utils.data import get_data -@dataclass class MobileCLIPModelConfig(EmbeddingModelConfig): model_name: str = "MobileCLIP2-S4" diff --git a/dimos/models/embedding/treid.py b/dimos/models/embedding/treid.py index 85e32cd39b..69cc1aae13 100644 --- a/dimos/models/embedding/treid.py +++ b/dimos/models/embedding/treid.py @@ -16,7 +16,6 @@ warnings.filterwarnings("ignore", message="Cython evaluation.*unavailable", category=UserWarning) -from dataclasses import dataclass from functools import cached_property import torch @@ -32,7 +31,6 @@ # osnet models downloaded from https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.html # into dimos/data/models_torchreid/ # feel free to add more -@dataclass class TorchReIDModelConfig(EmbeddingModelConfig): model_name: str = "osnet_x1_0" diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index 41b240eaf9..1cdeb3f92f 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -1,21 +1,24 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass import json import logging -from typing import TYPE_CHECKING, Any +import sys +from typing import Any import warnings from dimos.core.resource import Resource from dimos.msgs.sensor_msgs import Image -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D +from dimos.protocol.service.spec import BaseConfig, Configurable from dimos.utils.data import get_data from dimos.utils.decorators import retry from dimos.utils.llm_utils import extract_json -if TYPE_CHECKING: - from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar logger = logging.getLogger(__name__) @@ -159,15 +162,17 @@ def vlm_point_to_detection2d_point( ) -@dataclass -class VlModelConfig: +class VlModelConfig(BaseConfig): """Configuration for VlModel.""" auto_resize: tuple[int, int] | None = None """Optional (width, height) tuple. If set, images are resized to fit.""" -class VlModel(Captioner, Resource, Configurable[VlModelConfig]): +_VlConfig = TypeVar("_VlConfig", bound=VlModelConfig) + + +class VlModel(Captioner, Resource, Configurable[_VlConfig]): """Vision-language model that can answer questions about images. Inherits from Captioner, providing a default caption() implementation @@ -176,8 +181,7 @@ class VlModel(Captioner, Resource, Configurable[VlModelConfig]): Implements Resource interface for lifecycle management. """ - default_config = VlModelConfig - config: VlModelConfig + default_config: type[_VlConfig] = VlModelConfig # type: ignore[assignment] def _prepare_image(self, image: Image) -> tuple[Image, float]: """Prepare image for inference, applying any configured transformations. diff --git a/dimos/models/vl/create.py b/dimos/models/vl/create.py index 1f8819c8db..6c778d4104 100644 --- a/dimos/models/vl/create.py +++ b/dimos/models/vl/create.py @@ -1,11 +1,11 @@ -from typing import Literal +from typing import Any, Literal from dimos.models.vl.base import VlModel VlModelName = Literal["qwen", "moondream"] -def create(name: VlModelName) -> VlModel: +def create(name: VlModelName) -> VlModel[Any]: # This uses inline imports to only import what's needed. match name: case "qwen": diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index f31611e867..c444d8b9ed 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from functools import cached_property from typing import Any import warnings @@ -9,7 +8,7 @@ from transformers import AutoModelForCausalLM # type: ignore[import-untyped] from dimos.models.base import HuggingFaceModel, HuggingFaceModelConfig -from dimos.models.vl.base import VlModel +from dimos.models.vl.base import VlModel, VlModelConfig from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D @@ -17,8 +16,7 @@ MOONDREAM_DEFAULT_AUTO_RESIZE = (512, 512) -@dataclass -class MoondreamConfig(HuggingFaceModelConfig): +class MoondreamConfig(HuggingFaceModelConfig, VlModelConfig): """Configuration for MoondreamVlModel.""" model_name: str = "vikhyatk/moondream2" @@ -26,10 +24,9 @@ class MoondreamConfig(HuggingFaceModelConfig): auto_resize: tuple[int, int] | None = MOONDREAM_DEFAULT_AUTO_RESIZE -class MoondreamVlModel(HuggingFaceModel, VlModel): +class MoondreamVlModel(HuggingFaceModel, VlModel[MoondreamConfig]): _model_class = AutoModelForCausalLM default_config = MoondreamConfig # type: ignore[assignment] - config: MoondreamConfig # type: ignore[assignment] @cached_property def _model(self) -> AutoModelForCausalLM: diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py index fc1f8b7a17..57df91b47e 100644 --- a/dimos/models/vl/moondream_hosted.py +++ b/dimos/models/vl/moondream_hosted.py @@ -6,20 +6,21 @@ import numpy as np from PIL import Image as PILImage -from dimos.models.vl.base import VlModel +from dimos.models.vl.base import VlModel, VlModelConfig from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D -class MoondreamHostedVlModel(VlModel): - _api_key: str | None +class Config(VlModelConfig): + api_key: str | None = None - def __init__(self, api_key: str | None = None) -> None: - self._api_key = api_key + +class MoondreamHostedVlModel(VlModel[Config]): + default_config = Config @cached_property def _client(self) -> md.vl: - api_key = self._api_key or os.getenv("MOONDREAM_API_KEY") + api_key = self.config.api_key or os.getenv("MOONDREAM_API_KEY") if not api_key: raise ValueError( "Moondream API key must be provided or set in MOONDREAM_API_KEY environment variable" diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index f596f1ee1e..ec774189e4 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from functools import cached_property import os from typing import Any @@ -13,15 +12,13 @@ logger = setup_logger() -@dataclass class OpenAIVlModelConfig(VlModelConfig): model_name: str = "gpt-4o-mini" api_key: str | None = None -class OpenAIVlModel(VlModel): +class OpenAIVlModel(VlModel[OpenAIVlModelConfig]): default_config = OpenAIVlModelConfig - config: OpenAIVlModelConfig @cached_property def _client(self) -> OpenAI: diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index 93b31bf74c..014c6f73a5 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from functools import cached_property import os from typing import Any @@ -10,7 +9,6 @@ from dimos.msgs.sensor_msgs import Image -@dataclass class QwenVlModelConfig(VlModelConfig): """Configuration for Qwen VL model.""" @@ -18,9 +16,8 @@ class QwenVlModelConfig(VlModelConfig): api_key: str | None = None -class QwenVlModel(VlModel): +class QwenVlModel(VlModel[QwenVlModelConfig]): default_config = QwenVlModelConfig - config: QwenVlModelConfig @cached_property def _client(self) -> OpenAI: diff --git a/dimos/navigation/bbox_navigation.py b/dimos/navigation/bbox_navigation.py index e0752dfd00..170bff9bcd 100644 --- a/dimos/navigation/bbox_navigation.py +++ b/dimos/navigation/bbox_navigation.py @@ -18,7 +18,7 @@ from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 from dimos.msgs.vision_msgs import Detection2DArray @@ -27,17 +27,19 @@ logger = setup_logger(level=logging.DEBUG) -class BBoxNavigationModule(Module): +class Config(ModuleConfig): + goal_distance: float = 1.0 + + +class BBoxNavigationModule(Module[Config]): """Minimal module that converts 2D bbox center to navigation goals.""" + default_config = Config + detection2d: In[Detection2DArray] camera_info: In[CameraInfo] goal_request: Out[PoseStamped] - - def __init__(self, goal_distance: float = 1.0) -> None: - super().__init__() - self.goal_distance = goal_distance - self.camera_intrinsics = None + camera_intrinsics = None @rpc def start(self) -> None: @@ -62,9 +64,9 @@ def _on_detection(self, det: Detection2DArray) -> None: det.detections[0].bbox.center.position.y, ) x, y, z = ( - (center_x - cx) / fx * self.goal_distance, - (center_y - cy) / fy * self.goal_distance, - self.goal_distance, + (center_x - cx) / fx * self.config.goal_distance, + (center_y - cy) / fy * self.config.goal_distance, + self.config.goal_distance, ) goal = PoseStamped( position=Vector3(z, -x, -y), diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py index 1c8082b414..419986780a 100644 --- a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -262,7 +262,7 @@ def test_frontier_ranking(explorer) -> None: # Note: Goals might be closer than safe_distance if that's the best available frontier # The safe_distance is used for scoring, not as a hard constraint print( - f"Distance to obstacles: {obstacle_dist:.2f}m (safe distance: {explorer.safe_distance}m)" + f"Distance to obstacles: {obstacle_dist:.2f}m (safe distance: {explorer.config.safe_distance}m)" ) print(f"Frontier ranking test passed - selected goal at ({goal1.x:.2f}, {goal1.y:.2f})") diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 6e598e8316..f8a5436fc1 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -23,6 +23,7 @@ from dataclasses import dataclass from enum import IntFlag import threading +from typing import Any from dimos_lcm.std_msgs import Bool import numpy as np @@ -30,7 +31,7 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.mapping.occupancy.inflation import simple_inflate from dimos.msgs.geometry_msgs import PoseStamped, Vector3 @@ -78,7 +79,18 @@ def clear(self) -> None: self.points.clear() -class WavefrontFrontierExplorer(Module): +class WavefrontConfig(ModuleConfig): + min_frontier_perimeter: float = 0.5 + occupancy_threshold: int = 99 + safe_distance: float = 3.0 + lookahead_distance: float = 5.0 + max_explored_distance: float = 10.0 + info_gain_threshold: float = 0.03 + num_no_gain_attempts: int = 2 + goal_timeout: float = 15.0 + + +class WavefrontFrontierExplorer(Module[WavefrontConfig]): """ Wavefront frontier exploration algorithm implementation. @@ -93,6 +105,8 @@ class WavefrontFrontierExplorer(Module): - goal_request: Exploration goals sent to the navigator """ + default_config = WavefrontConfig + # LCM inputs global_costmap: In[OccupancyGrid] odom: In[PoseStamped] @@ -103,17 +117,7 @@ class WavefrontFrontierExplorer(Module): # LCM outputs goal_request: Out[PoseStamped] - def __init__( - self, - min_frontier_perimeter: float = 0.5, - occupancy_threshold: int = 99, - safe_distance: float = 3.0, - lookahead_distance: float = 5.0, - max_explored_distance: float = 10.0, - info_gain_threshold: float = 0.03, - num_no_gain_attempts: int = 2, - goal_timeout: float = 15.0, - ) -> None: + def __init__(self, **kwargs: Any) -> None: """ Initialize the frontier explorer. @@ -124,20 +128,12 @@ def __init__( info_gain_threshold: Minimum percentage increase in costmap information required to continue exploration (0.05 = 5%) num_no_gain_attempts: Maximum number of consecutive attempts with no information gain """ - super().__init__() - self.min_frontier_perimeter = min_frontier_perimeter - self.occupancy_threshold = occupancy_threshold - self.safe_distance = safe_distance - self.max_explored_distance = max_explored_distance - self.lookahead_distance = lookahead_distance - self.info_gain_threshold = info_gain_threshold - self.num_no_gain_attempts = num_no_gain_attempts + super().__init__(**kwargs) self._cache = FrontierCache() self.explored_goals = [] # type: ignore[var-annotated] # list of explored goals self.exploration_direction = Vector3(0.0, 0.0, 0.0) # current exploration direction self.last_costmap = None # store last costmap for information comparison self.no_gain_counter = 0 # track consecutive no-gain attempts - self.goal_timeout = goal_timeout # Latest data self.latest_costmap: OccupancyGrid | None = None @@ -214,7 +210,7 @@ def _count_costmap_information(self, costmap: OccupancyGrid) -> int: Number of cells that are free space or obstacles (not unknown) """ free_count = np.sum(costmap.grid == CostValues.FREE) - obstacle_count = np.sum(costmap.grid >= self.occupancy_threshold) + obstacle_count = np.sum(costmap.grid >= self.config.occupancy_threshold) return int(free_count + obstacle_count) def _get_neighbors(self, point: GridPoint, costmap: OccupancyGrid) -> list[GridPoint]: @@ -252,7 +248,7 @@ def _is_frontier_point(self, point: GridPoint, costmap: OccupancyGrid) -> bool: neighbor_cost = costmap.grid[neighbor.y, neighbor.x] # If adjacent to occupied space, not a frontier - if neighbor_cost > self.occupancy_threshold: + if neighbor_cost > self.config.occupancy_threshold: return False # Check if adjacent to free space @@ -376,7 +372,7 @@ def detect_frontiers(self, robot_pose: Vector3, costmap: OccupancyGrid) -> list[ # Check if we found a large enough frontier # Convert minimum perimeter to minimum number of cells based on resolution - min_cells = int(self.min_frontier_perimeter / costmap.resolution) + min_cells = int(self.config.min_frontier_perimeter / costmap.resolution) if len(new_frontier) >= min_cells: world_points = [] for point in new_frontier: @@ -489,7 +485,7 @@ def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGr min_distance = float("inf") search_radius = ( - int(self.safe_distance / costmap.resolution) + 5 + int(self.config.safe_distance / costmap.resolution) + 5 ) # Search a bit beyond minimum # Search in a square around the frontier point @@ -508,14 +504,14 @@ def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGr continue # Check if this cell is an obstacle - if costmap.grid[check_y, check_x] >= self.occupancy_threshold: + if costmap.grid[check_y, check_x] >= self.config.occupancy_threshold: # Calculate distance in meters distance = np.sqrt(dx**2 + dy**2) * costmap.resolution min_distance = min(min_distance, distance) # If no obstacles found within search radius, return the safe distance # This indicates the frontier is safely away from obstacles - return min_distance if min_distance != float("inf") else self.safe_distance + return min_distance if min_distance != float("inf") else self.config.safe_distance def _compute_comprehensive_frontier_score( self, frontier: Vector3, frontier_size: int, robot_pose: Vector3, costmap: OccupancyGrid @@ -527,25 +523,25 @@ def _compute_comprehensive_frontier_score( # Distance score: prefer moderate distances (not too close, not too far) # Normalized to 0-1 range - distance_score = 1.0 / (1.0 + abs(robot_distance - self.lookahead_distance)) + distance_score = 1.0 / (1.0 + abs(robot_distance - self.config.lookahead_distance)) # 2. Information gain (frontier size) # Normalize by a reasonable max frontier size - max_expected_frontier_size = self.min_frontier_perimeter / costmap.resolution * 10 + max_expected_frontier_size = self.config.min_frontier_perimeter / costmap.resolution * 10 info_gain_score = min(frontier_size / max_expected_frontier_size, 1.0) # 3. Distance to explored goals (bonus for being far from explored areas) # Normalize by a reasonable max distance (e.g., 10 meters) explored_goals_distance = self._compute_distance_to_explored_goals(frontier) - explored_goals_score = min(explored_goals_distance / self.max_explored_distance, 1.0) + explored_goals_score = min(explored_goals_distance / self.config.max_explored_distance, 1.0) # 4. Distance to obstacles (score based on safety) # 0 = too close to obstacles, 1 = at or beyond safe distance obstacles_distance = self._compute_distance_to_obstacles(frontier, costmap) - if obstacles_distance >= self.safe_distance: + if obstacles_distance >= self.config.safe_distance: obstacles_score = 1.0 # Fully safe else: - obstacles_score = obstacles_distance / self.safe_distance # Linear penalty + obstacles_score = obstacles_distance / self.config.safe_distance # Linear penalty # 5. Direction momentum (already in 0-1 range from dot product) momentum_score = self._compute_direction_momentum_score(frontier, robot_pose) @@ -628,15 +624,15 @@ def get_exploration_goal(self, robot_pose: Vector3, costmap: OccupancyGrid) -> V # Check if information increase meets minimum percentage threshold if last_info > 0: # Avoid division by zero info_increase_percent = (current_info - last_info) / last_info - if info_increase_percent < self.info_gain_threshold: + if info_increase_percent < self.config.info_gain_threshold: logger.info( - f"Information increase ({info_increase_percent:.2f}) below threshold ({self.info_gain_threshold:.2f})" + f"Information increase ({info_increase_percent:.2f}) below threshold ({self.config.info_gain_threshold:.2f})" ) logger.info( f"Current information: {current_info}, Last information: {last_info}" ) self.no_gain_counter += 1 - if self.no_gain_counter >= self.num_no_gain_attempts: + if self.no_gain_counter >= self.config.num_no_gain_attempts: logger.info( f"No information gain for {self.no_gain_counter} consecutive attempts" ) @@ -797,7 +793,7 @@ def _exploration_loop(self) -> None: # Wait for goal to be reached or timeout logger.info("Waiting for goal to be reached...") - goal_reached = self.goal_reached_event.wait(timeout=self.goal_timeout) + goal_reached = self.goal_reached_event.wait(timeout=self.config.goal_timeout) if goal_reached: logger.info("Goal reached, finding next frontier") diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index 4dad9a2843..28a22a2a86 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -13,12 +13,12 @@ # limitations under the License. import os +from typing import Any from dimos_lcm.std_msgs import Bool, String from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs import PointStamped, PoseStamped, Twist @@ -41,12 +41,10 @@ class ReplanningAStarPlanner(Module, NavigationInterface): navigation_costmap: Out[OccupancyGrid] _planner: GlobalPlanner - _global_config: GlobalConfig - def __init__(self, cfg: GlobalConfig = global_config) -> None: - super().__init__() - self._global_config = cfg - self._planner = GlobalPlanner(self._global_config) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._planner = GlobalPlanner(self.config.g) @rpc def start(self) -> None: diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 89b299ae5b..230d94b50f 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -18,11 +18,12 @@ Encapsulates ROS transport and topic remapping for Unitree robots. """ -from dataclasses import dataclass, field import logging import threading import time +from typing import Any +from pydantic import Field from reactivex import operators as ops from reactivex.subject import Subject @@ -52,19 +53,22 @@ logger = setup_logger(level=logging.INFO) -@dataclass class Config(ModuleConfig): local_pointcloud_freq: float = 2.0 global_map_freq: float = 1.0 - sensor_to_base_link_transform: Transform = field( + sensor_to_base_link_transform: Transform = Field( default_factory=lambda: Transform(frame_id="sensor", child_frame_id="base_link") ) class ROSNav( - Module, NavigationInterface, spec.Nav, spec.GlobalPointcloud, spec.Pointcloud, spec.LocalPlanner + Module[Config], + NavigationInterface, + spec.Nav, + spec.GlobalPointcloud, + spec.Pointcloud, + spec.LocalPlanner, ): - config: Config default_config = Config # Existing ports (default LCM/pSHM transport) @@ -106,8 +110,8 @@ class ROSNav( _current_goal: PoseStamped | None = None _goal_reached: bool = False - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # Initialize RxPY Subjects for streaming data self._local_pointcloud_subject = Subject() diff --git a/dimos/navigation/visual/query.py b/dimos/navigation/visual/query.py index 37b743506a..0c84e8ac34 100644 --- a/dimos/navigation/visual/query.py +++ b/dimos/navigation/visual/query.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from dimos.models.qwen.bbox import BBox from dimos.models.vl.base import VlModel @@ -20,7 +21,7 @@ def get_object_bbox_from_image( - vl_model: VlModel, image: Image, object_description: str + vl_model: VlModel[Any], image: Image, object_description: str ) -> BBox | None: prompt = ( f"Look at this image and find the '{object_description}'. " diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index e81ab2ab4a..8c1a65eb8b 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -15,6 +15,7 @@ from collections.abc import Callable, Generator import functools from typing import TypedDict +from unittest import mock from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations from dimos_lcm.foxglove_msgs.SceneUpdate import SceneUpdate @@ -204,7 +205,8 @@ def detection3dpc(detections3dpc) -> Detection3DPC: def get_moment_2d(get_moment) -> Generator[Callable[[], Moment2D], None, None]: from dimos.perception.detection.detectors import Yolo2DDetector - module = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) + c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) + module = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) @functools.lru_cache(maxsize=1) def moment_provider(**kwargs) -> Moment2D: @@ -262,7 +264,8 @@ def object_db_module(get_moment): """Create and populate an ObjectDBModule with detections from multiple frames.""" from dimos.perception.detection.detectors import Yolo2DDetector - module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) + c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) + module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) module3d = Detection3DModule(camera_info=connection._camera_info_static()) moduleDB = ObjectDBModule(camera_info=connection._camera_info_static()) diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index f86794a1f7..0a07b1238d 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any +from collections.abc import Callable, Sequence +from typing import Annotated, Any from dimos_lcm.foxglove_msgs.ImageAnnotations import ( ImageAnnotations, ) +from pydantic.experimental.pipeline import validate_as from reactivex import operators as ops from reactivex.observable import Observable from reactivex.subject import Subject @@ -38,24 +38,21 @@ from dimos.utils.reactive import backpressure -@dataclass class Config(ModuleConfig): max_freq: float = 10 detector: Callable[[Any], Detector] | None = Yolo2DDetector publish_detection_images: bool = True - camera_info: CameraInfo = None # type: ignore[assignment] - filter: list[Filter2D] | Filter2D | None = None + camera_info: CameraInfo + filter: Annotated[ + Sequence[Filter2D], + validate_as(Sequence[Filter2D] | Filter2D).transform( + lambda f: f if isinstance(f, Sequence) else (f,) + ), + ] = () - def __post_init__(self) -> None: - if self.filter is None: - self.filter = [] - elif not isinstance(self.filter, list): - self.filter = [self.filter] - -class Detection2DModule(Module): +class Detection2DModule(Module[Config]): default_config = Config - config: Config detector: Detector color_image: In[Image] diff --git a/dimos/perception/experimental/temporal_memory/entity_graph_db.py b/dimos/perception/experimental/temporal_memory/entity_graph_db.py index 0d5531bada..a2f5b41cbf 100644 --- a/dimos/perception/experimental/temporal_memory/entity_graph_db.py +++ b/dimos/perception/experimental/temporal_memory/entity_graph_db.py @@ -557,7 +557,7 @@ def estimate_and_save_distances( self, parsed: dict[str, Any], frame_image: Image, - vlm: VlModel, + vlm: VlModel[Any], timestamp_s: float, max_distance_pairs: int = 5, ) -> None: diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py index b651d3e0af..7d01522417 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -22,13 +22,12 @@ from __future__ import annotations from collections import deque -from dataclasses import dataclass import json import os from pathlib import Path import threading import time -from typing import TYPE_CHECKING, Any +from typing import Any from reactivex import Subject, interval from reactivex.disposable import Disposable @@ -37,6 +36,7 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out +from dimos.models.vl.base import VlModel from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import sharpness_barrier @@ -50,9 +50,6 @@ from .temporal_state import TemporalState from .window_analyzer import WindowAnalyzer -if TYPE_CHECKING: - from dimos.models.vl.base import VlModel - try: from .clip_filter import CLIPFrameFilter except ImportError: @@ -63,7 +60,6 @@ MAX_RECENT_WINDOWS = 50 -@dataclass class TemporalMemoryConfig(ModuleConfig): """Configuration for the temporal memory module. @@ -71,6 +67,8 @@ class TemporalMemoryConfig(ModuleConfig): tune cost / latency / accuracy without touching code. """ + vlm: VlModel[Any] | None = None + # Frame processing fps: float = 1.0 window_s: float = 5.0 @@ -106,38 +104,35 @@ class TemporalMemoryConfig(ModuleConfig): nearby_distance_meters: float = 5.0 -class TemporalMemory(Module): +class TemporalMemory(Module[TemporalMemoryConfig]): """Thin orchestrator that wires frames → window accumulator → VLM → state + DB. Uses RxPY reactive streams for the frame pipeline and ``interval`` for periodic window analysis. """ + default_config = TemporalMemoryConfig + color_image: In[Image] odom: In[PoseStamped] entity_markers: Out[EntityMarkers] - def __init__( - self, - vlm: VlModel | None = None, - config: TemporalMemoryConfig | None = None, - ) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) - self._vlm_raw = vlm - self._config: TemporalMemoryConfig = config or TemporalMemoryConfig() + self._vlm_raw = self.config.vlm # new_memory is set via TemporalMemoryConfig by the blueprint factory # (which runs in the main process where GlobalConfig is available). # Components self._accumulator = FrameWindowAccumulator( - max_buffer_frames=self._config.max_buffer_frames, - window_s=self._config.window_s, - stride_s=self._config.stride_s, - fps=self._config.fps, + max_buffer_frames=self.config.max_buffer_frames, + window_s=self.config.window_s, + stride_s=self.config.stride_s, + fps=self.config.fps, ) - self._state = TemporalState(next_summary_at_s=self._config.summary_interval_s) + self._state = TemporalState(next_summary_at_s=self.config.summary_interval_s) self._recent_windows: deque[dict[str, Any]] = deque(maxlen=MAX_RECENT_WINDOWS) self._stopped = False @@ -150,10 +145,10 @@ def __init__( # CLIP filter self._clip_filter: CLIPFrameFilter | None = None - self._use_clip_filtering = self._config.use_clip_filtering + self._use_clip_filtering = self.config.use_clip_filtering if self._use_clip_filtering and CLIP_AVAILABLE: try: - self._clip_filter = CLIPFrameFilter(model_name=self._config.clip_model) + self._clip_filter = CLIPFrameFilter(model_name=self.config.clip_model) logger.info("clip filtering enabled") except Exception as e: logger.warning(f"clip init failed: {e}") @@ -163,8 +158,8 @@ def __init__( self._use_clip_filtering = False # Persistent DB — stored in XDG state dir (same root as per-run logs) - if self._config.db_dir: - db_dir = Path(self._config.db_dir) + if self.config.db_dir: + db_dir = Path(self.config.db_dir) else: # Default: ~/.local/state/dimos/temporal_memory/ # XDG state dir — predictable, works for pip install and git clone. @@ -173,7 +168,7 @@ def __init__( db_dir = state_root / "dimos" / "temporal_memory" db_dir.mkdir(parents=True, exist_ok=True) db_path = db_dir / "entity_graph.db" - if self._config.new_memory and db_path.exists(): + if self.config.new_memory and db_path.exists(): db_path.unlink() logger.info("Deleted existing DB (new_memory=True)") self._graph_db = EntityGraphDB(db_path=db_path) @@ -181,7 +176,7 @@ def __init__( # Persistent JSONL — accumulates across runs (raw VLM output + parsed) self._persistent_jsonl_path: Path = db_dir / "temporal_memory.jsonl" - if self._config.new_memory and self._persistent_jsonl_path.exists(): + if self.config.new_memory and self._persistent_jsonl_path.exists(): self._persistent_jsonl_path.unlink() logger.info("Deleted existing persistent JSONL (new_memory=True)") logger.info(f"persistent JSONL: {self._persistent_jsonl_path}") @@ -204,8 +199,8 @@ def __init__( logger.warning("no run log dir found — JSONL logging disabled") logger.info( - f"TemporalMemory init: fps={self._config.fps}, " - f"window={self._config.window_s}s, stride={self._config.stride_s}s" + f"TemporalMemory init: fps={self.config.fps}, " + f"window={self.config.window_s}s, stride={self.config.stride_s}s" ) # ------------------------------------------------------------------ @@ -213,7 +208,7 @@ def __init__( # ------------------------------------------------------------------ @property - def vlm(self) -> VlModel: + def vlm(self) -> VlModel[Any]: if self._vlm_raw is None: from dimos.models.vl.openai import OpenAIVlModel @@ -230,8 +225,8 @@ def _analyzer(self) -> WindowAnalyzer: if not hasattr(self, "__analyzer"): self.__analyzer = WindowAnalyzer( self.vlm, - max_tokens=self._config.max_tokens, - temperature=self._config.temperature, + max_tokens=self.config.max_tokens, + temperature=self.config.temperature, ) return self.__analyzer @@ -261,7 +256,7 @@ def _log_jsonl(self, record: dict[str, Any]) -> None: def _publish_entity_markers(self) -> None: """Publish entity positions as 3D markers for Rerun overlay on the map.""" - if not self._config.visualize: + if not self.config.visualize: return try: all_entities = self._graph_db.get_all_entities() @@ -319,7 +314,7 @@ def _on_frame(img: Image) -> None: ) self._disposables.add( - frame_subject.pipe(sharpness_barrier(self._config.fps)).subscribe(_on_frame) + frame_subject.pipe(sharpness_barrier(self.config.fps)).subscribe(_on_frame) ) unsub_image = self.color_image.subscribe(frame_subject.on_next) self._disposables.add(Disposable(unsub_image)) @@ -342,7 +337,7 @@ def _on_odom(msg: PoseStamped) -> None: # Periodic window analysis self._disposables.add( - interval(self._config.stride_s).subscribe(lambda _: self._analyze_window()) + interval(self.config.stride_s).subscribe(lambda _: self._analyze_window()) ) logger.info("TemporalMemory started") @@ -366,7 +361,7 @@ def stop(self) -> None: self._accumulator.clear() self._recent_windows.clear() - self._state.clear(self._config.summary_interval_s) + self._state.clear(self.config.summary_interval_s) super().stop() @@ -401,13 +396,13 @@ def _analyze_window(self) -> None: w_start, w_end = window_frames[0].timestamp_s, window_frames[-1].timestamp_s # Skip stale scenes (frames too close together / camera not moving) - if tu.is_scene_stale(window_frames, self._config.stale_scene_threshold): + if tu.is_scene_stale(window_frames, self.config.stale_scene_threshold): logger.info(f"[temporal-memory] skipping stale window [{w_start:.1f}-{w_end:.1f}s]") return # Select diverse keyframes window_frames = adaptive_keyframes( - window_frames, max_frames=self._config.max_frames_per_window + window_frames, max_frames=self.config.max_frames_per_window ) logger.info(f"analyzing [{w_start:.1f}-{w_end:.1f}s] with {len(window_frames)} frames") @@ -458,7 +453,7 @@ def _analyze_window(self) -> None: ) # VLM Call #2: distance estimation (background thread) - if self._graph_db and self._config.enable_distance_estimation and window_frames: + if self._graph_db and self.config.enable_distance_estimation and window_frames: mid_frame = window_frames[len(window_frames) // 2] if mid_frame.image: thread = threading.Thread( @@ -468,7 +463,7 @@ def _analyze_window(self) -> None: mid_frame.image, self.vlm, w_end, - self._config.max_distance_pairs, + self.config.max_distance_pairs, ), daemon=True, ) @@ -478,7 +473,7 @@ def _analyze_window(self) -> None: # Update state needs_summary = self._state.update_from_window( - parsed, w_end, self._config.summary_interval_s + parsed, w_end, self.config.summary_interval_s ) self._recent_windows.append(parsed) @@ -512,7 +507,7 @@ def _update_rolling_summary(self, w_end: float) -> None: sr = self._analyzer.update_summary(latest.image, snap.rolling_summary, snap.chunk_buffer) if sr is not None: - self._state.apply_summary(sr.summary_text, w_end, self._config.summary_interval_s) + self._state.apply_summary(sr.summary_text, w_end, self.config.summary_interval_s) self._log_jsonl( { "ts": time.time(), @@ -592,8 +587,8 @@ def query(self, question: str) -> str: graph_db=self._graph_db, entity_ids=all_entity_ids, time_window_s=time_window_s, - max_relations_per_entity=self._config.max_relations_per_entity, - nearby_distance_meters=self._config.nearby_distance_meters, + max_relations_per_entity=self.config.max_relations_per_entity, + nearby_distance_meters=self.config.nearby_distance_meters, current_video_time_s=current_video_time_s, ) context["graph_knowledge"] = graph_context @@ -623,7 +618,7 @@ def query(self, question: str) -> str: @rpc def clear_history(self) -> bool: try: - self._state.clear(self._config.summary_interval_s) + self._state.clear(self.config.summary_interval_s) self._recent_windows.clear() logger.info("cleared history") return True diff --git a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py index d98074bd5d..abaa99dede 100644 --- a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py +++ b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py @@ -21,7 +21,7 @@ import threading import time from typing import TYPE_CHECKING -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch from dotenv import load_dotenv import numpy as np @@ -33,6 +33,7 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import Out from dimos.core.transport import LCMTransport +from dimos.models.vl.base import VlModel from dimos.msgs.sensor_msgs import Image from dimos.perception.experimental.temporal_memory import ( Frame, @@ -337,8 +338,9 @@ def test_new_memory_clears_db(self, tmp_path: Path) -> None: return_value=None, ): tm = TemporalMemory( - vlm=MagicMock(), - config=TemporalMemoryConfig(db_dir=str(db_dir), new_memory=True), + vlm=create_autospec(VlModel, spec_set=True, instance=True), + db_dir=str(db_dir), + new_memory=True, ) # DB should be empty since we cleared it stats = tm._graph_db.get_stats() @@ -361,8 +363,9 @@ def test_persistent_memory_survives(self, tmp_path: Path) -> None: return_value=None, ): tm = TemporalMemory( - vlm=MagicMock(), - config=TemporalMemoryConfig(db_dir=str(db_dir), new_memory=False), + vlm=create_autospec(VlModel, spec_set=True, instance=True), + db_dir=str(db_dir), + new_memory=False, ) stats = tm._graph_db.get_stats() assert stats["entities"] == 1 @@ -386,8 +389,8 @@ def test_log_entries(self, tmp_path: Path) -> None: return_value=log_dir, ): tm = TemporalMemory( - vlm=MagicMock(), - config=TemporalMemoryConfig(db_dir=str(db_dir)), + vlm=create_autospec(VlModel, spec_set=True, instance=True), + db_dir=str(db_dir), ) jsonl_path = log_dir / "temporal_memory" / "temporal_memory.jsonl" @@ -427,8 +430,9 @@ def test_publish_entity_markers(self, tmp_path: Path) -> None: return_value=None, ): tm = TemporalMemory( - vlm=MagicMock(), - config=TemporalMemoryConfig(db_dir=str(db_dir), visualize=True), + vlm=create_autospec(VlModel, spec_set=True, instance=True), + db_dir=str(db_dir), + visualize=True, ) # Populate DB with world positions @@ -487,7 +491,7 @@ class TestWindowAnalyzer: def test_analyze_window_calls_vlm(self) -> None: from dimos.perception.experimental.temporal_memory.window_analyzer import WindowAnalyzer - mock_vlm = MagicMock() + mock_vlm = create_autospec(VlModel, spec_set=True, instance=True) mock_vlm.query.return_value = json.dumps( { "window": {"start_s": 0.0, "end_s": 2.0}, @@ -513,7 +517,7 @@ def test_analyze_window_calls_vlm(self) -> None: def test_analyze_window_vlm_error(self) -> None: from dimos.perception.experimental.temporal_memory.window_analyzer import WindowAnalyzer - mock_vlm = MagicMock() + mock_vlm = create_autospec(VlModel, spec_set=True, instance=True) mock_vlm.query.side_effect = RuntimeError("VLM error") analyzer = WindowAnalyzer(mock_vlm) @@ -527,7 +531,7 @@ def test_analyze_window_vlm_error(self) -> None: def test_update_summary(self) -> None: from dimos.perception.experimental.temporal_memory.window_analyzer import WindowAnalyzer - mock_vlm = MagicMock() + mock_vlm = create_autospec(VlModel, spec_set=True, instance=True) mock_vlm.query.return_value = "Updated summary text" analyzer = WindowAnalyzer(mock_vlm) @@ -540,7 +544,7 @@ def test_update_summary(self) -> None: def test_answer_query(self) -> None: from dimos.perception.experimental.temporal_memory.window_analyzer import WindowAnalyzer - mock_vlm = MagicMock() + mock_vlm = create_autospec(VlModel, spec_set=True, instance=True) mock_vlm.query.return_value = "The answer is 42" analyzer = WindowAnalyzer(mock_vlm) diff --git a/dimos/perception/experimental/temporal_memory/window_analyzer.py b/dimos/perception/experimental/temporal_memory/window_analyzer.py index a8b1899258..70bfec8d74 100644 --- a/dimos/perception/experimental/temporal_memory/window_analyzer.py +++ b/dimos/perception/experimental/temporal_memory/window_analyzer.py @@ -68,13 +68,15 @@ class WindowAnalyzer: Stateless — caller provides frames, state snapshots, and config. """ - def __init__(self, vlm: VlModel, *, max_tokens: int = 900, temperature: float = 0.2) -> None: + def __init__( + self, vlm: VlModel[Any], *, max_tokens: int = 900, temperature: float = 0.2 + ) -> None: self._vlm = vlm self.max_tokens = max_tokens self.temperature = temperature @property - def vlm(self) -> VlModel: + def vlm(self) -> VlModel[Any]: return self._vlm # ------------------------------------------------------------------ diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index da415ac32a..29a9ecc034 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import threading import time +from typing import Any import cv2 @@ -51,9 +51,10 @@ logger = setup_logger() -@dataclass class ObjectTrackingConfig(ModuleConfig): frame_id: str = "camera_link" + reid_threshold: int = 10 + reid_fail_tolerance: int = 5 class ObjectTracking(Module[ObjectTrackingConfig]): @@ -70,11 +71,8 @@ class ObjectTracking(Module[ObjectTrackingConfig]): tracked_overlay: Out[Image] # Visualization output default_config = ObjectTrackingConfig - config: ObjectTrackingConfig - def __init__( - self, reid_threshold: int = 10, reid_fail_tolerance: int = 5, **kwargs: object - ) -> None: + def __init__(self, **kwargs: Any) -> None: """ Initialize an object tracking module using OpenCV's CSRT tracker with ORB re-ID. @@ -89,8 +87,6 @@ def __init__( super().__init__(**kwargs) self.camera_intrinsics = None - self.reid_threshold = reid_threshold - self.reid_fail_tolerance = reid_fail_tolerance self.tracker = None self.tracking_bbox = None # Stores (x, y, w, h) for tracker initialization @@ -276,7 +272,7 @@ def reid(self, frame, current_bbox) -> bool: # type: ignore[no-untyped-def] good_matches += 1 self.last_good_matches = good_matches_list # Store good matches for visualization - return good_matches >= self.reid_threshold + return good_matches >= self.config.reid_threshold def _start_tracking_thread(self) -> None: """Start the tracking thread.""" @@ -389,7 +385,7 @@ def _process_tracking(self) -> None: # Determine final success if tracker_succeeded: - if self.reid_fail_count >= self.reid_fail_tolerance: + if self.reid_fail_count >= self.config.reid_fail_tolerance: logger.warning( f"Re-ID failed consecutively {self.reid_fail_count} times. Target lost." ) @@ -589,11 +585,11 @@ def _draw_reid_matches(self, image: NDArray[np.uint8]) -> NDArray[np.uint8]: # f"REID: WARMING UP ({self.tracking_frame_count}/{self.reid_warmup_frames})" ) status_color = (255, 255, 0) # Yellow - elif len(self.last_good_matches) >= self.reid_threshold: + elif len(self.last_good_matches) >= self.config.reid_threshold: status_text = "REID: CONFIRMED" status_color = (0, 255, 0) # Green else: - status_text = f"REID: WEAK ({self.reid_fail_count}/{self.reid_fail_tolerance})" + status_text = f"REID: WEAK ({self.reid_fail_count}/{self.config.reid_fail_tolerance})" status_color = (0, 165, 255) # Orange cv2.putText( diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index 1264b0e92b..03f3991081 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import logging import threading import time +from typing import Any import cv2 @@ -43,7 +43,6 @@ logger = setup_logger(level=logging.INFO) -@dataclass class ObjectTracker2DConfig(ModuleConfig): frame_id: str = "camera_link" @@ -57,9 +56,8 @@ class ObjectTracker2D(Module[ObjectTracker2DConfig]): tracked_overlay: Out[Image] # Visualization output default_config = ObjectTracker2DConfig - config: ObjectTracker2DConfig - def __init__(self, **kwargs: object) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize 2D object tracking module using OpenCV's CSRT tracker.""" super().__init__(**kwargs) diff --git a/dimos/perception/perceive_loop_skill.py b/dimos/perception/perceive_loop_skill.py index 53362977f5..0d84e40897 100644 --- a/dimos/perception/perceive_loop_skill.py +++ b/dimos/perception/perceive_loop_skill.py @@ -16,7 +16,7 @@ import json from threading import RLock -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from langchain_core.messages import HumanMessage @@ -34,8 +34,6 @@ if TYPE_CHECKING: from reactivex.abc import DisposableBase - from dimos.core.global_config import GlobalConfig - from dimos.models.vl.base import VlModel logger = setup_logger() @@ -46,13 +44,9 @@ class PerceiveLoopSkill(Module): _agent_spec: AgentSpec _period: float = 0.5 # seconds - how often to run the perceive loop - def __init__( - self, - cfg: GlobalConfig, - ) -> None: - super().__init__() - self._global_config: GlobalConfig = cfg - self._vl_model: VlModel = create(cfg.detection_model) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._vl_model = create(self.config.g.detection_model) self._active_lookout: tuple[str, ...] = () self._lookout_subscription: DisposableBase | None = None self._model_started: bool = False diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index bf62d50bcf..0cb4ab74c1 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -19,7 +19,7 @@ from datetime import datetime import os import time -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import uuid import cv2 @@ -33,7 +33,7 @@ from dimos.agents_deprecated.memory.visual_memory import VisualMemory from dimos.constants import DIMOS_PROJECT_ROOT from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In from dimos.msgs.sensor_msgs import Image @@ -53,7 +53,23 @@ logger = setup_logger() -class SpatialMemory(Module): +class SpatialConfig(ModuleConfig): + collection_name: str = "spatial_memory" + embedding_model: str = "clip" + embedding_dimensions: int = 512 + min_distance_threshold: float = 0.01 # Min distance in meters to store a new frame + min_time_threshold: float = 1.0 # Min time in seconds to record a new frame + db_path: str | None = str(_DB_PATH) # Path for ChromaDB persistence + visual_memory_path: str | None = str( + _VISUAL_MEMORY_PATH + ) # Path for saving/loading visual memory + new_memory: bool = True # Whether to create a new memory from scratch + output_dir: str | None = str(_SPATIAL_MEMORY_DIR) # Directory for storing visual memory data + chroma_client: Any = None # Optional ChromaDB client for persistence + visual_memory: VisualMemory | None = None # Optional VisualMemory instance for storing images + + +class SpatialMemory(Module[SpatialConfig]): """ A Dimos module for building and querying Robot spatial memory. @@ -63,29 +79,12 @@ class SpatialMemory(Module): robot locations that can be queried by name. """ + default_config = SpatialConfig + # LCM inputs color_image: In[Image] - def __init__( - self, - collection_name: str = "spatial_memory", - embedding_model: str = "clip", - embedding_dimensions: int = 512, - min_distance_threshold: float = 0.01, # Min distance in meters to store a new frame - min_time_threshold: float = 1.0, # Min time in seconds to record a new frame - db_path: str | None = str(_DB_PATH), # Path for ChromaDB persistence - visual_memory_path: str | None = str( - _VISUAL_MEMORY_PATH - ), # Path for saving/loading visual memory - new_memory: bool = True, # Whether to create a new memory from scratch - output_dir: str | None = str( - _SPATIAL_MEMORY_DIR - ), # Directory for storing visual memory data - chroma_client: Any = None, # Optional ChromaDB client for persistence - visual_memory: Optional[ - "VisualMemory" - ] = None, # Optional VisualMemory instance for storing images - ) -> None: + def __init__(self, **kwargs: Any) -> None: """ Initialize the spatial perception system. @@ -99,39 +98,36 @@ def __init__( visual_memory: Optional VisualMemory instance for storing images output_dir: Directory for storing visual memory data if visual_memory is not provided """ - self.collection_name = collection_name - self.embedding_model = embedding_model - self.embedding_dimensions = embedding_dimensions - self.min_distance_threshold = min_distance_threshold - self.min_time_threshold = min_time_threshold - - # Set up paths for persistence - # Call parent Module init - super().__init__() + super().__init__(**kwargs) - self.db_path = db_path - self.visual_memory_path = visual_memory_path + self.collection_name = self.config.collection_name + self.embedding_model = self.config.embedding_model + self.embedding_dimensions = self.config.embedding_dimensions + self.min_distance_threshold = self.config.min_distance_threshold + self.min_time_threshold = self.config.min_time_threshold + self.db_path = self.config.db_path + self.visual_memory_path = self.config.visual_memory_path # Setup ChromaDB client if not provided - self._chroma_client = chroma_client - if chroma_client is None and db_path is not None: + self._chroma_client = self.config.chroma_client + if self._chroma_client is None and self.db_path is not None: # Create db directory if needed - os.makedirs(db_path, exist_ok=True) + os.makedirs(self.db_path, exist_ok=True) # Clean up existing DB if creating new memory - if new_memory and os.path.exists(db_path): + if self.config.new_memory and os.path.exists(self.db_path): try: logger.info("Creating new ChromaDB database (new_memory=True)") # Try to delete any existing database files import shutil - for item in os.listdir(db_path): - item_path = os.path.join(db_path, item) + for item in os.listdir(self.db_path): + item_path = os.path.join(self.db_path, item) if os.path.isfile(item_path): os.unlink(item_path) elif os.path.isdir(item_path): shutil.rmtree(item_path) - logger.info(f"Removed existing ChromaDB files from {db_path}") + logger.info(f"Removed existing ChromaDB files from {self.db_path}") except Exception as e: logger.error(f"Error clearing ChromaDB directory: {e}") @@ -139,33 +135,33 @@ def __init__( from chromadb.config import Settings self._chroma_client = chromadb.PersistentClient( - path=db_path, settings=Settings(anonymized_telemetry=False) + path=self.db_path, settings=Settings(anonymized_telemetry=False) ) # Initialize or load visual memory - self._visual_memory = visual_memory - if visual_memory is None: - if new_memory or not os.path.exists(visual_memory_path or ""): + self._visual_memory = self.config.visual_memory + if self._visual_memory is None: + if self.config.new_memory or not os.path.exists(self.visual_memory_path or ""): logger.info("Creating new visual memory") - self._visual_memory = VisualMemory(output_dir=output_dir) + self._visual_memory = VisualMemory(output_dir=self.config.output_dir) else: try: - logger.info(f"Loading existing visual memory from {visual_memory_path}...") + logger.info(f"Loading existing visual memory from {self.visual_memory_path}...") self._visual_memory = VisualMemory.load( - visual_memory_path, # type: ignore[arg-type] - output_dir=output_dir, + self.visual_memory_path, # type: ignore[arg-type] + output_dir=self.config.output_dir, ) logger.info(f"Loaded {self._visual_memory.count()} images from previous runs") except Exception as e: logger.error(f"Error loading visual memory: {e}") - self._visual_memory = VisualMemory(output_dir=output_dir) + self._visual_memory = VisualMemory(output_dir=self.config.output_dir) self.embedding_provider: ImageEmbeddingProvider = ImageEmbeddingProvider( - model_name=embedding_model, dimensions=embedding_dimensions + model_name=self.embedding_model, dimensions=self.embedding_dimensions ) self.vector_db: SpatialVectorDB = SpatialVectorDB( - collection_name=collection_name, + collection_name=self.collection_name, chroma_client=self._chroma_client, visual_memory=self._visual_memory, embedding_provider=self.embedding_provider, @@ -184,7 +180,7 @@ def __init__( self._latest_video_frame: np.ndarray | None = None # type: ignore[type-arg] self._process_interval = 1 - logger.info(f"SpatialMemory initialized with model {embedding_model}") + logger.info(f"SpatialMemory initialized with model {self.embedding_model}") @rpc def start(self) -> None: diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py index ac9b132a69..22aa4d4ce8 100644 --- a/dimos/perception/test_spatial_memory_module.py +++ b/dimos/perception/test_spatial_memory_module.py @@ -20,7 +20,7 @@ from reactivex import operators as ops from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import Out from dimos.core.transport import LCMTransport @@ -35,21 +35,22 @@ logger = setup_logger() -class VideoReplayModule(Module): +class VideoReplayConfig(ModuleConfig): + video_path: str + + +class VideoReplayModule(Module[VideoReplayConfig]): """Module that replays video data from TimedSensorReplay.""" + default_config = VideoReplayConfig video_out: Out[Image] - - def __init__(self, video_path: str) -> None: - super().__init__() - self.video_path = video_path - self._subscription = None + _subscription = None @rpc def start(self) -> None: """Start replaying video data.""" # Use TimedSensorReplay to replay video frames - video_replay = TimedSensorReplay(self.video_path, autocast=Image.from_numpy) + video_replay = TimedSensorReplay(self.config.video_path, autocast=Image.from_numpy) # Subscribe to the replay stream and publish to LCM self._subscription = ( diff --git a/dimos/protocol/pubsub/bridge.py b/dimos/protocol/pubsub/bridge.py index f312caed7b..72cbe155d9 100644 --- a/dimos/protocol/pubsub/bridge.py +++ b/dimos/protocol/pubsub/bridge.py @@ -16,10 +16,9 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING, Generic, Protocol, TypeVar -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service if TYPE_CHECKING: from collections.abc import Callable @@ -66,8 +65,7 @@ def pass_msg(msg: MsgFrom, topic: TopicFrom) -> None: return pubsub1.subscribe_all(pass_msg) -@dataclass -class BridgeConfig(Generic[TopicFrom, TopicTo, MsgFrom, MsgTo]): +class BridgeConfig(BaseConfig, Generic[TopicFrom, TopicTo, MsgFrom, MsgTo]): """Configuration for a one-way bridge.""" source: AllPubSub[TopicFrom, MsgFrom] diff --git a/dimos/protocol/pubsub/impl/lcmpubsub.py b/dimos/protocol/pubsub/impl/lcmpubsub.py index bf6bbd0dec..4e792f5965 100644 --- a/dimos/protocol/pubsub/impl/lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/lcmpubsub.py @@ -14,10 +14,13 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass import re -from typing import TYPE_CHECKING, Any +import threading +from typing import Any +from dimos.msgs import DimosMsg from dimos.protocol.pubsub.encoders import ( JpegEncoderMixin, LCMEncoderMixin, @@ -25,15 +28,9 @@ ) from dimos.protocol.pubsub.patterns import Glob from dimos.protocol.pubsub.spec import AllPubSub -from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf +from dimos.protocol.service.lcmservice import LCMService, autoconf from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from collections.abc import Callable - import threading - - from dimos.msgs import DimosMsg - logger = setup_logger() @@ -83,7 +80,6 @@ class LCMPubSubBase(LCMService, AllPubSub[Topic, Any]): RegexSubscribable directly without needing discovery-based fallback. """ - default_config = LCMConfig _stop_event: threading.Event _thread: threading.Thread | None diff --git a/dimos/protocol/pubsub/impl/redispubsub.py b/dimos/protocol/pubsub/impl/redispubsub.py index 6cc089e953..b299d6b883 100644 --- a/dimos/protocol/pubsub/impl/redispubsub.py +++ b/dimos/protocol/pubsub/impl/redispubsub.py @@ -14,25 +14,24 @@ from collections import defaultdict from collections.abc import Callable -from dataclasses import dataclass, field import json import threading import time from types import TracebackType from typing import Any +from pydantic import Field import redis # type: ignore[import-not-found] from dimos.protocol.pubsub.spec import PubSub -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service -@dataclass -class RedisConfig: +class RedisConfig(BaseConfig): host: str = "localhost" port: int = 6379 db: int = 0 - kwargs: dict[str, Any] = field(default_factory=dict) + kwargs: dict[str, Any] = Field(default_factory=dict) class Redis(PubSub[str, Any], Service[RedisConfig]): diff --git a/dimos/protocol/service/__init__.py b/dimos/protocol/service/__init__.py index fb9df08ca9..ed6caf93c2 100644 --- a/dimos/protocol/service/__init__.py +++ b/dimos/protocol/service/__init__.py @@ -1,8 +1,9 @@ from dimos.protocol.service.lcmservice import LCMService -from dimos.protocol.service.spec import Configurable as Configurable, Service as Service +from dimos.protocol.service.spec import BaseConfig, Configurable, Service -__all__ = [ +__all__ = ( + "BaseConfig", "Configurable", "LCMService", "Service", -] +) diff --git a/dimos/protocol/service/ddsservice.py b/dimos/protocol/service/ddsservice.py index 6ed04c07ad..b5562defff 100644 --- a/dimos/protocol/service/ddsservice.py +++ b/dimos/protocol/service/ddsservice.py @@ -14,9 +14,8 @@ from __future__ import annotations -from dataclasses import dataclass import threading -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING try: from cyclonedds.domain import DomainParticipant @@ -26,7 +25,7 @@ DDS_AVAILABLE = False DomainParticipant = None # type: ignore[assignment, misc] -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -38,8 +37,7 @@ _participants_lock = threading.Lock() -@dataclass -class DDSConfig: +class DDSConfig(BaseConfig): """Configuration for DDS service.""" domain_id: int = 0 @@ -49,9 +47,6 @@ class DDSConfig: class DDSService(Service[DDSConfig]): default_config = DDSConfig - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - def start(self) -> None: """Start the DDS service.""" domain_id = self.config.domain_id diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py index 5cd4563fd1..9a563addb1 100644 --- a/dimos/protocol/service/lcmservice.py +++ b/dimos/protocol/service/lcmservice.py @@ -15,18 +15,24 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass import os import platform +import sys import threading import traceback +from typing import Any -import lcm +import lcm as lcm_mod -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service from dimos.protocol.service.system_configurator import configure_system, lcm_configurators from dimos.utils.logging_config import setup_logger +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + logger = setup_logger() _DEFAULT_LCM_HOST = "239.255.76.67" @@ -45,40 +51,37 @@ def autoconf(check_only: bool = False) -> None: configure_system(checks, check_only=check_only) -@dataclass -class LCMConfig: +class LCMConfig(BaseConfig): ttl: int = 0 - url: str | None = None - lcm: lcm.LCM | None = None - - def __post_init__(self) -> None: - if self.url is None: - self.url = _DEFAULT_LCM_URL + url: str = _DEFAULT_LCM_URL + lcm: lcm_mod.LCM | None = None +_Config = TypeVar("_Config", bound=LCMConfig, default=LCMConfig) _LCM_LOOP_TIMEOUT = 50 # this class just sets up cpp LCM instance # and runs its handle loop in a thread # higher order stuff is done by pubsub/impl/lcmpubsub.py -class LCMService(Service[LCMConfig]): - default_config = LCMConfig - l: lcm.LCM | None +class LCMService(Service[_Config]): + default_config = LCMConfig # type: ignore[assignment] + + l: lcm_mod.LCM | None _stop_event: threading.Event _l_lock: threading.Lock _thread: threading.Thread | None _call_thread_pool: ThreadPoolExecutor | None = None _call_thread_pool_lock: threading.RLock = threading.RLock() - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) # we support passing an existing LCM instance if self.config.lcm: self.l = self.config.lcm else: - self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + self.l = lcm_mod.LCM(self.config.url) if self.config.url else lcm_mod.LCM() self._l_lock = threading.Lock() self._stop_event = threading.Event() @@ -113,7 +116,7 @@ def start(self) -> None: if self.config.lcm: self.l = self.config.lcm else: - self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + self.l = lcm_mod.LCM(self.config.url) if self.config.url else lcm_mod.LCM() self._stop_event.clear() self._thread = threading.Thread(target=self._lcm_loop) diff --git a/dimos/protocol/service/spec.py b/dimos/protocol/service/spec.py index c4e6758614..c9796cf2b5 100644 --- a/dimos/protocol/service/spec.py +++ b/dimos/protocol/service/spec.py @@ -13,17 +13,24 @@ # limitations under the License. from abc import ABC -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar + +from pydantic import BaseModel + + +class BaseConfig(BaseModel): + model_config = {"arbitrary_types_allowed": True, "extra": "forbid"} + # Generic type for service configuration -ConfigT = TypeVar("ConfigT") +ConfigT = TypeVar("ConfigT", bound=BaseConfig) class Configurable(Generic[ConfigT]): default_config: type[ConfigT] - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] - self.config: ConfigT = self.default_config(**kwargs) + def __init__(self, **kwargs: Any) -> None: + self.config = self.default_config(**kwargs) class Service(Configurable[ConfigT], ABC): diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py index 857bc305a2..a647c89c86 100644 --- a/dimos/protocol/service/test_lcmservice.py +++ b/dimos/protocol/service/test_lcmservice.py @@ -14,7 +14,9 @@ import threading import time -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch + +from lcm import LCM from dimos.protocol.pubsub.impl.lcmpubsub import Topic from dimos.protocol.service.lcmservice import ( @@ -100,10 +102,6 @@ def test_custom_url(self) -> None: config = LCMConfig(url=custom_url) assert config.url == custom_url - def test_post_init_sets_default_url_when_none(self) -> None: - config = LCMConfig(url=None) - assert config.url == _DEFAULT_LCM_URL - # ----------------------------- Topic tests ----------------------------- @@ -125,8 +123,8 @@ def test_str_with_lcm_type(self) -> None: class TestLCMService: def test_init_with_default_config(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -136,8 +134,8 @@ def test_init_with_default_config(self) -> None: def test_init_with_custom_url(self) -> None: custom_url = "udpm://192.168.1.1:7777?ttl=1" - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance # Pass url as kwarg, not config= @@ -145,17 +143,17 @@ def test_init_with_custom_url(self) -> None: mock_lcm_class.assert_called_once_with(custom_url) def test_init_with_existing_lcm_instance(self) -> None: - mock_lcm_instance = MagicMock() + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: # Pass lcm as kwarg service = LCMService(lcm=mock_lcm_instance) mock_lcm_class.assert_not_called() assert service.l == mock_lcm_instance def test_start_and_stop(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -172,8 +170,8 @@ def test_start_and_stop(self) -> None: assert not service._thread.is_alive() def test_getstate_excludes_unpicklable_attrs(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -187,8 +185,8 @@ def test_getstate_excludes_unpicklable_attrs(self) -> None: assert "_call_thread_pool_lock" not in state def test_setstate_reinitializes_runtime_attrs(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -207,8 +205,8 @@ def test_setstate_reinitializes_runtime_attrs(self) -> None: assert hasattr(new_service._l_lock, "release") def test_start_reinitializes_lcm_after_unpickling(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -227,8 +225,8 @@ def test_start_reinitializes_lcm_after_unpickling(self) -> None: new_service.stop() def test_stop_cleans_up_lcm_instance(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -239,7 +237,7 @@ def test_stop_cleans_up_lcm_instance(self) -> None: assert service.l is None def test_stop_preserves_external_lcm_instance(self) -> None: - mock_lcm_instance = MagicMock() + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) # Pass lcm as kwarg service = LCMService(lcm=mock_lcm_instance) @@ -250,8 +248,8 @@ def test_stop_preserves_external_lcm_instance(self) -> None: assert service.l == mock_lcm_instance def test_get_call_thread_pool_creates_pool(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -269,8 +267,8 @@ def test_get_call_thread_pool_creates_pool(self) -> None: pool.shutdown(wait=False) def test_stop_shuts_down_thread_pool(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 825e89fc8c..1b5ccadf3c 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -16,7 +16,7 @@ from abc import abstractmethod from collections import deque -from dataclasses import dataclass, field +from dataclasses import field from functools import reduce from typing import TypeVar @@ -25,23 +25,22 @@ from dimos.msgs.tf2_msgs import TFMessage from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic from dimos.protocol.pubsub.spec import PubSub -from dimos.protocol.service.lcmservice import Service # type: ignore[attr-defined] +from dimos.protocol.service.spec import BaseConfig, Service CONFIG = TypeVar("CONFIG") # generic configuration for transform service -@dataclass -class TFConfig: +class TFConfig(BaseConfig): buffer_size: float = 10.0 # seconds rate_limit: float = 10.0 # Hz -# generic specification for transform service -class TFSpec(Service[TFConfig]): - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] - super().__init__(**kwargs) +_TFConfig = TypeVar("_TFConfig", bound=TFConfig) + +# generic specification for transform service +class TFSpec(Service[_TFConfig]): @abstractmethod def publish(self, *args: Transform) -> None: ... @@ -244,15 +243,17 @@ def __str__(self) -> str: return "\n".join(lines) -@dataclass class PubSubTFConfig(TFConfig): topic: Topic | None = None # Required field but needs default for dataclass inheritance pubsub: type[PubSub] | PubSub | None = None # type: ignore[type-arg] autostart: bool = True -class PubSubTF(MultiTBuffer, TFSpec): - default_config: type[PubSubTFConfig] = PubSubTFConfig +_PubSubConfig = TypeVar("_PubSubConfig", bound=PubSubTFConfig) + + +class PubSubTF(MultiTBuffer, TFSpec[_PubSubConfig]): + default_config: type[_PubSubConfig] = PubSubTFConfig # type: ignore[assignment] def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] TFSpec.__init__(self, **kwargs) @@ -330,15 +331,14 @@ def receive_msg(self, msg: TFMessage, topic: Topic) -> None: self.receive_tfmessage(msg) -@dataclass class LCMPubsubConfig(PubSubTFConfig): topic: Topic = field(default_factory=lambda: Topic("/tf", TFMessage)) pubsub: type[PubSub] | PubSub | None = LCM # type: ignore[type-arg] autostart: bool = True -class LCMTF(PubSubTF): - default_config: type[LCMPubsubConfig] = LCMPubsubConfig +class LCMTF(PubSubTF[LCMPubsubConfig]): + default_config = LCMPubsubConfig TF = LCMTF diff --git a/dimos/protocol/tf/tflcmcpp.py b/dimos/protocol/tf/tflcmcpp.py index 158a68d3d8..bf2885958d 100644 --- a/dimos/protocol/tf/tflcmcpp.py +++ b/dimos/protocol/tf/tflcmcpp.py @@ -13,15 +13,18 @@ # limitations under the License. from datetime import datetime -from typing import Union from dimos.msgs.geometry_msgs import Transform from dimos.protocol.service.lcmservice import LCMConfig, LCMService from dimos.protocol.tf.tf import TFConfig, TFSpec +class Config(TFConfig, LCMConfig): + """Combined config""" + + # this doesn't work due to tf_lcm_py package -class TFLCM(TFSpec, LCMService): +class TFLCM(TFSpec[Config], LCMService[Config]): """A service for managing and broadcasting transforms using LCM. This is not a separete module, You can include this in your module if you need to access transforms. @@ -34,7 +37,7 @@ class TFLCM(TFSpec, LCMService): for each module. """ - default_config = Union[TFConfig, LCMConfig] + default_config = Config def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) diff --git a/dimos/robot/drone/connection_module.py b/dimos/robot/drone/connection_module.py index 7b44cea607..c606e7467e 100644 --- a/dimos/robot/drone/connection_module.py +++ b/dimos/robot/drone/connection_module.py @@ -26,7 +26,7 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.mapping.types import LatLon from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 @@ -45,9 +45,17 @@ def _add_disposable(composite: CompositeDisposable, item: Disposable | Any) -> N composite.add(Disposable(item)) -class DroneConnectionModule(Module): +class Config(ModuleConfig): + connection_string: str = "udp:0.0.0.0:14550" + video_port: int = 5600 + outdoor: bool = False + + +class DroneConnectionModule(Module[Config]): """Module that handles drone sensor data and movement commands.""" + default_config = Config + # Inputs movecmd: In[Vector3] movecmd_twist: In[Twist] # Twist commands from tracking/navigation @@ -62,9 +70,6 @@ class DroneConnectionModule(Module): video: Out[Image] follow_object_cmd: Out[Any] - # Parameters - connection_string: str - # Internal state _odom: PoseStamped | None = None _status: dict[str, Any] = {} @@ -73,14 +78,7 @@ class DroneConnectionModule(Module): _latest_status: dict[str, Any] | None = None _latest_status_lock: threading.RLock - def __init__( - self, - connection_string: str = "udp:0.0.0.0:14550", - video_port: int = 5600, - outdoor: bool = False, - *args: Any, - **kwargs: Any, - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize drone connection module. Args: @@ -88,9 +86,7 @@ def __init__( video_port: UDP port for video stream outdoor: Use GPS only mode (no velocity integration) """ - self.connection_string = connection_string - self.video_port = video_port - self.outdoor = outdoor + super().__init__(**kwargs) self.connection: MavlinkConnection | None = None self.video_stream: DJIDroneVideoStream | None = None self._latest_video_frame = None @@ -99,23 +95,24 @@ def __init__( self._latest_status_lock = threading.RLock() self._running = False self._telemetry_thread: threading.Thread | None = None - Module.__init__(self, *args, **kwargs) @rpc def start(self) -> None: """Start the connection and subscribe to sensor streams.""" # Check for replay mode - if self.connection_string == "replay": + if self.config.connection_string == "replay": from dimos.robot.drone.dji_video_stream import FakeDJIVideoStream from dimos.robot.drone.mavlink_connection import FakeMavlinkConnection self.connection = FakeMavlinkConnection("replay") - self.video_stream = FakeDJIVideoStream(port=self.video_port) + self.video_stream = FakeDJIVideoStream(port=self.config.video_port) else: - self.connection = MavlinkConnection(self.connection_string, outdoor=self.outdoor) + self.connection = MavlinkConnection( + self.config.connection_string, outdoor=self.config.outdoor + ) self.connection.connect() - self.video_stream = DJIDroneVideoStream(port=self.video_port) + self.video_stream = DJIDroneVideoStream(port=self.config.video_port) if not self.connection.connected: logger.error("Failed to connect to drone") diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index 78fbdaf168..9f0fc938e5 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -13,21 +13,21 @@ # limitations under the License. import asyncio +from collections.abc import Sequence import logging import threading -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from dimos_lcm.foxglove_bridge import ( FoxgloveBridge as LCMFoxgloveBridge, ) from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.core.global_config import GlobalConfig from dimos.core.rpc_client import ModuleProxy logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) @@ -36,31 +36,23 @@ logger = setup_logger() -class FoxgloveBridge(Module): +class FoxgloveConfig(ModuleConfig): + shm_channels: Sequence[str] = () + jpeg_shm_channels: Sequence[str] = () + + +class FoxgloveBridge(Module[FoxgloveConfig]): _thread: threading.Thread _loop: asyncio.AbstractEventLoop - _global_config: "GlobalConfig | None" = None - - def __init__( - self, - *args: Any, - shm_channels: list[str] | None = None, - jpeg_shm_channels: list[str] | None = None, - global_config: "GlobalConfig | None" = None, - **kwargs: Any, - ) -> None: - super().__init__(*args, **kwargs) - self.shm_channels = shm_channels or [] - self.jpeg_shm_channels = jpeg_shm_channels or [] - self._global_config = global_config + default_config = FoxgloveConfig @rpc def start(self) -> None: super().start() # Skip if Rerun is the selected viewer - if self._global_config and self._global_config.viewer.startswith("rerun"): - logger.info("Foxglove bridge skipped", viewer=self._global_config.viewer) + if self.config.g.viewer.startswith("rerun"): + logger.info("Foxglove bridge skipped", viewer=self.config.g.viewer) return def run_bridge() -> None: @@ -78,8 +70,8 @@ def run_bridge() -> None: port=8765, debug=False, num_threads=4, - shm_channels=self.shm_channels, - jpeg_shm_channels=self.jpeg_shm_channels, + shm_channels=self.config.shm_channels, + jpeg_shm_channels=self.config.jpeg_shm_channels, ) self._loop.run_until_complete(bridge.run()) except Exception as e: diff --git a/dimos/robot/unitree/b1/connection.py b/dimos/robot/unitree/b1/connection.py index 4279f78399..445044020d 100644 --- a/dimos/robot/unitree/b1/connection.py +++ b/dimos/robot/unitree/b1/connection.py @@ -21,11 +21,12 @@ import socket import threading import time +from typing import Any from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry @@ -48,13 +49,21 @@ class RobotMode: RECOVERY = 6 -class B1ConnectionModule(Module): +class B1ConnectionConfig(ModuleConfig): + ip: str = "192.168.12.1" + port: int = 9090 + test_mode: bool = False + + +class B1ConnectionModule(Module[B1ConnectionConfig]): """UDP connection module for B1 robot with standard Twist interface. Accepts standard ROS Twist messages on /cmd_vel and mode changes on /b1/mode, internally converts to B1Command format, and sends UDP packets at 50Hz. """ + default_config = B1ConnectionConfig + # LCM ports (inter-module communication) cmd_vel: In[TwistStamped] mode_cmd: In[Int32] @@ -67,9 +76,7 @@ class B1ConnectionModule(Module): ros_odom_in: In[Odometry] ros_tf: In[TFMessage] - def __init__( # type: ignore[no-untyped-def] - self, ip: str = "192.168.12.1", port: int = 9090, test_mode: bool = False, *args, **kwargs - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize B1 connection module. Args: @@ -77,11 +84,11 @@ def __init__( # type: ignore[no-untyped-def] port: UDP port for joystick server test_mode: If True, print commands instead of sending UDP """ - Module.__init__(self, *args, **kwargs) + super().__init__(**kwargs) - self.ip = ip - self.port = port - self.test_mode = test_mode + self.ip = self.config.ip + self.port = self.config.port + self.test_mode = self.config.test_mode self.current_mode = RobotMode.IDLE # Start in IDLE mode self._current_cmd = B1Command(mode=RobotMode.IDLE) self.cmd_lock = threading.Lock() # Thread lock for _current_cmd access @@ -383,9 +390,10 @@ def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> bool: class MockB1ConnectionModule(B1ConnectionModule): """Test connection module that prints commands instead of sending UDP.""" - def __init__(self, ip: str = "127.0.0.1", port: int = 9090, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: # type: ignore[no-untyped-def] """Initialize test connection without creating socket.""" - super().__init__(ip, port, test_mode=True, *args, **kwargs) # type: ignore[misc] + kwargs["test_mode"] = True + super().__init__(**kwargs) def _send_loop(self) -> None: """Override to provide better test output with timeout detection.""" diff --git a/dimos/robot/unitree/b1/joystick_module.py b/dimos/robot/unitree/b1/joystick_module.py index 0a72f81617..9fbfd84f1e 100644 --- a/dimos/robot/unitree/b1/joystick_module.py +++ b/dimos/robot/unitree/b1/joystick_module.py @@ -41,12 +41,9 @@ class JoystickModule(Module): twist_out: Out[TwistStamped] # Timestamped velocity commands mode_out: Out[Int32] # Mode changes - - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - Module.__init__(self, *args, **kwargs) - self.pygame_ready = False - self.running = False - self.current_mode = 0 # Start in IDLE mode for safety + pygame_ready = False + running = False + current_mode = 0 # Start in IDLE mode for safety @rpc def start(self) -> None: diff --git a/dimos/robot/unitree/b1/unitree_b1.py b/dimos/robot/unitree/b1/unitree_b1.py index 2c0c918942..6b374d1d5b 100644 --- a/dimos/robot/unitree/b1/unitree_b1.py +++ b/dimos/robot/unitree/b1/unitree_b1.py @@ -92,9 +92,9 @@ def start(self) -> None: logger.info("Deploying connection module...") if self.test_mode: - self.connection = self._dimos.deploy(MockB1ConnectionModule, self.ip, self.port) # type: ignore[assignment] + self.connection = self._dimos.deploy(MockB1ConnectionModule, ip=self.ip, port=self.port) # type: ignore[assignment] else: - self.connection = self._dimos.deploy(B1ConnectionModule, self.ip, self.port) # type: ignore[assignment] + self.connection = self._dimos.deploy(B1ConnectionModule, ip=self.ip, port=self.port) # type: ignore[assignment] # Configure LCM transports for connection (matching G1 pattern) self.connection.cmd_vel.transport = LCMTransport("/cmd_vel", TwistStamped) # type: ignore[attr-defined] diff --git a/dimos/robot/unitree/g1/connection.py b/dimos/robot/unitree/g1/connection.py index c2dbc6ab2d..94f725ac7e 100644 --- a/dimos/robot/unitree/g1/connection.py +++ b/dimos/robot/unitree/g1/connection.py @@ -14,14 +14,14 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar +from pydantic import Field from reactivex.disposable import Disposable from dimos import spec from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In from dimos.msgs.geometry_msgs import Twist @@ -32,9 +32,15 @@ from dimos.core.rpc_client import ModuleProxy logger = setup_logger() +_Config = TypeVar("_Config", bound=ModuleConfig) -class G1ConnectionBase(Module, ABC): +class G1Config(ModuleConfig): + ip: str = Field(default_factory=lambda m: m["g"].robot_ip) + connection_type: str = Field(default_factory=lambda m: m["g"].unitree_connection_type) + + +class G1ConnectionBase(Module[_Config], ABC): """Abstract base for G1 connections (real hardware and simulation). Modules that depend on G1 connection RPC methods should reference this @@ -61,36 +67,19 @@ def move(self, twist: Twist, duration: float = 0.0) -> None: ... def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: ... -class G1Connection(G1ConnectionBase): +class G1Connection(G1ConnectionBase[G1Config]): + default_config = G1Config + cmd_vel: In[Twist] - ip: str | None - connection_type: str | None = None - _global_config: GlobalConfig - - connection: UnitreeWebRTCConnection | None - - def __init__( - self, - ip: str | None = None, - connection_type: str | None = None, - cfg: GlobalConfig = global_config, - *args: Any, - **kwargs: Any, - ) -> None: - self._global_config = cfg - self.ip = ip if ip is not None else self._global_config.robot_ip - self.connection_type = connection_type or self._global_config.unitree_connection_type - self.connection = None - super().__init__(*args, **kwargs) + connection: UnitreeWebRTCConnection | None = None @rpc def start(self) -> None: super().start() - match self.connection_type: + match self.config.connection_type: case "webrtc": - assert self.ip is not None, "IP address must be provided" - self.connection = UnitreeWebRTCConnection(self.ip) + self.connection = UnitreeWebRTCConnection(self.config.ip) case "replay": raise ValueError("Replay connection not implemented for G1 robot") case "mujoco": @@ -98,7 +87,7 @@ def start(self) -> None: "This module does not support simulation, use G1SimConnection instead" ) case _: - raise ValueError(f"Unknown connection type: {self.connection_type}") + raise ValueError(f"Unknown connection type: {self.config.connection_type}") assert self.connection is not None self.connection.start() @@ -127,7 +116,7 @@ def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: def deploy(dimos: ModuleCoordinator, ip: str, local_planner: spec.LocalPlanner) -> "ModuleProxy": - connection = dimos.deploy(G1Connection, ip) # type: ignore[attr-defined] + connection = dimos.deploy(G1Connection, ip=ip) connection.cmd_vel.connect(local_planner.cmd_vel) connection.start() return connection diff --git a/dimos/robot/unitree/g1/sim.py b/dimos/robot/unitree/g1/sim.py index 06950c6f0d..9226bb4e7f 100644 --- a/dimos/robot/unitree/g1/sim.py +++ b/dimos/robot/unitree/g1/sim.py @@ -16,12 +16,13 @@ import threading from threading import Thread import time -from typing import TYPE_CHECKING, Any +from typing import Any +from pydantic import Field from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.module import ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs import ( PoseStamped, @@ -32,37 +33,31 @@ ) from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 from dimos.robot.unitree.g1.connection import G1ConnectionBase +from dimos.robot.unitree.mujoco_connection import MujocoConnection from dimos.robot.unitree.type.odometry import Odometry as SimOdometry from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from dimos.robot.unitree.mujoco_connection import MujocoConnection - logger = setup_logger() -class G1SimConnection(G1ConnectionBase): +class G1SimConfig(ModuleConfig): + ip: str = Field(default_factory=lambda m: m["g"].robot_ip) + + +class G1SimConnection(G1ConnectionBase[G1SimConfig]): + default_config = G1SimConfig + cmd_vel: In[Twist] lidar: Out[PointCloud2] odom: Out[PoseStamped] color_image: Out[Image] camera_info: Out[CameraInfo] - ip: str | None - _global_config: GlobalConfig + connection: MujocoConnection | None = None _camera_info_thread: Thread | None = None - def __init__( - self, - ip: str | None = None, - cfg: GlobalConfig = global_config, - *args: Any, - **kwargs: Any, - ) -> None: - self._global_config = cfg - self.ip = ip if ip is not None else self._global_config.robot_ip - self.connection: MujocoConnection | None = None + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._stop_event = threading.Event() - super().__init__(*args, **kwargs) @rpc def start(self) -> None: @@ -70,7 +65,7 @@ def start(self) -> None: from dimos.robot.unitree.mujoco_connection import MujocoConnection - self.connection = MujocoConnection(self._global_config) + self.connection = MujocoConnection(self.config.g) assert self.connection is not None self.connection.start() diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index afd5c25ed6..c06028ec6f 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -13,10 +13,12 @@ # limitations under the License. import logging +import sys from threading import Thread import time from typing import TYPE_CHECKING, Any, Protocol +from pydantic import Field from reactivex.disposable import Disposable from reactivex.observable import Observable import rerun.blueprint as rrb @@ -24,8 +26,8 @@ from dimos import spec from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module +from dimos.core.global_config import GlobalConfig +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, pSHMTransport @@ -46,9 +48,18 @@ from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.testing.replay import TimedSensorReplay, TimedSensorStorage +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + logger = logging.getLogger(__name__) +class ConnectionConfig(ModuleConfig): + ip: str = Field(default_factory=lambda m: m["g"].robot_ip) + + class Go2ConnectionProtocol(Protocol): """Protocol defining the interface for Go2 robot connections.""" @@ -170,7 +181,12 @@ def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-de return {"status": "ok", "message": "Fake publish"} -class GO2Connection(Module, spec.Camera, spec.Pointcloud): +_Config = TypeVar("_Config", bound=ConnectionConfig, default=ConnectionConfig) + + +class GO2Connection(Module[_Config], spec.Camera, spec.Pointcloud): + default_config = ConnectionConfig # type: ignore[assignment] + cmd_vel: In[Twist] pointcloud: Out[PointCloud2] odom: Out[PoseStamped] @@ -180,7 +196,6 @@ class GO2Connection(Module, spec.Camera, spec.Pointcloud): connection: Go2ConnectionProtocol camera_info_static: CameraInfo = _camera_info_static() - _global_config: GlobalConfig _camera_info_thread: Thread | None = None _latest_video_frame: Image | None = None @@ -194,23 +209,13 @@ def rerun_views(cls): # type: ignore[no-untyped-def] ), ] - def __init__( # type: ignore[no-untyped-def] - self, - ip: str | None = None, - cfg: GlobalConfig = global_config, - *args, - **kwargs, - ) -> None: - self._global_config = cfg - - ip = ip if ip is not None else self._global_config.robot_ip - self.connection = make_connection(ip, self._global_config) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.connection = make_connection(self.config.ip, self.config.g) if hasattr(self.connection, "camera_info_static"): self.camera_info_static = self.connection.camera_info_static - Module.__init__(self, *args, **kwargs) - @rpc def record(self, recording_name: str) -> None: lidar_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/lidar") # type: ignore[type-arg] @@ -246,7 +251,7 @@ def onimage(image: Image) -> None: self.standup() time.sleep(3) self.connection.balance_stand() - self.connection.set_obstacle_avoidance(self._global_config.obstacle_avoidance) + self.connection.set_obstacle_avoidance(self.config.g.obstacle_avoidance) # self.record("go2_bigoffice") @@ -339,7 +344,7 @@ def observe(self) -> Image | None: def deploy(dimos: ModuleCoordinator, ip: str, prefix: str = "") -> "ModuleProxy": from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE - connection = dimos.deploy(GO2Connection, ip) # type: ignore[attr-defined] + connection = dimos.deploy(GO2Connection, ip=ip) connection.pointcloud.transport = pSHMTransport( f"{prefix}/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE diff --git a/dimos/robot/unitree/go2/fleet_connection.py b/dimos/robot/unitree/go2/fleet_connection.py index 4dd2be2984..24a95ec4d2 100644 --- a/dimos/robot/unitree/go2/fleet_connection.py +++ b/dimos/robot/unitree/go2/fleet_connection.py @@ -16,52 +16,62 @@ from __future__ import annotations +from collections.abc import Sequence +import sys from typing import TYPE_CHECKING, Any +from pydantic import Field, model_validator + from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.robot.unitree.go2.connection import ( + ConnectionConfig, GO2Connection, Go2ConnectionProtocol, make_connection, ) from dimos.utils.logging_config import setup_logger +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing import Any as Self + if TYPE_CHECKING: from dimos.msgs.geometry_msgs import Twist logger = setup_logger() -class Go2FleetConnection(GO2Connection): +class FleetConnectionConfig(ConnectionConfig): + ips: Sequence[str] = Field( + default_factory=lambda m: [ip.strip() for ip in m["g"].robot_ips.split(",")] + ) + + @model_validator(mode="after") + def set_ip_after_validation(self) -> Self: + if self.ip is None: + self.ip = self.ips[0] + return self + + +class Go2FleetConnection(GO2Connection[FleetConnectionConfig]): """Inherits all single-robot behaviour from GO2Connection for the primary (first) robot. Additional robots only receive broadcast commands (move, standup, liedown, publish_request). """ - def __init__( - self, - ips: list[str] | None = None, - cfg: GlobalConfig = global_config, - *args: object, - **kwargs: object, - ) -> None: - if not ips: - raw = cfg.robot_ips - if not raw: - raise ValueError( - "No IPs provided. Pass ips= or set ROBOT_IPS (e.g. ROBOT_IPS=10.0.0.102,10.0.0.209)" - ) - ips = [ip.strip() for ip in raw.split(",") if ip.strip()] - self._extra_ips = ips[1:] + default_config = FleetConnectionConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._extra_ips = self.config.ips[1:] self._extra_connections: list[Go2ConnectionProtocol] = [] - super().__init__(ips[0], cfg, *args, **kwargs) @rpc def start(self) -> None: self._extra_connections.clear() for ip in self._extra_ips: - conn = make_connection(ip, self._global_config) + conn = make_connection(ip, self.config.g) conn.start() self._extra_connections.append(conn) @@ -69,7 +79,7 @@ def start(self) -> None: super().start() for conn in self._extra_connections: conn.balance_stand() - conn.set_obstacle_avoidance(self._global_config.obstacle_avoidance) + conn.set_obstacle_avoidance(self.config.g.obstacle_avoidance) @rpc def stop(self) -> None: diff --git a/dimos/robot/unitree/keyboard_teleop.py b/dimos/robot/unitree/keyboard_teleop.py index 14be8432e5..3cd03df785 100644 --- a/dimos/robot/unitree/keyboard_teleop.py +++ b/dimos/robot/unitree/keyboard_teleop.py @@ -15,6 +15,7 @@ import os import threading +from typing import Any import pygame @@ -42,8 +43,8 @@ class KeyboardTeleop(Module): _clock: pygame.time.Clock | None = None _font: pygame.font.Font | None = None - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._stop_event = threading.Event() @rpc diff --git a/dimos/robot/unitree/mujoco_connection.py b/dimos/robot/unitree/mujoco_connection.py index 36673ecb3e..3bc4e075f7 100644 --- a/dimos/robot/unitree/mujoco_connection.py +++ b/dimos/robot/unitree/mujoco_connection.py @@ -126,6 +126,7 @@ def start(self) -> None: self.process = subprocess.Popen( [executable, str(LAUNCHER_PATH), config_pickle, shm_names_json], + stderr=subprocess.PIPE, ) except Exception as e: diff --git a/dimos/robot/unitree/rosnav.py b/dimos/robot/unitree/rosnav.py index adc97eb4a2..083c7413fe 100644 --- a/dimos/robot/unitree/rosnav.py +++ b/dimos/robot/unitree/rosnav.py @@ -33,11 +33,7 @@ class NavigationModule(Module): goal_reached: In[Bool] cancel_goal: Out[Bool] joy: Out[Joy] - - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - """Initialize NavigationModule.""" - Module.__init__(self, *args, **kwargs) - self.goal_reach = None + goal_reach = None @rpc def start(self) -> None: diff --git a/dimos/robot/unitree/type/map.py b/dimos/robot/unitree/type/map.py index 95b2bf6f6b..274115d516 100644 --- a/dimos/robot/unitree/type/map.py +++ b/dimos/robot/unitree/type/map.py @@ -21,8 +21,7 @@ from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport @@ -34,39 +33,35 @@ from dimos.robot.unitree.go2.connection import Go2ConnectionProtocol -class Map(Module): +class MapConfig(ModuleConfig): + voxel_size: float = 0.05 + cost_resolution: float = 0.05 + global_publish_interval: float | None = None + min_height: float = 0.10 + max_height: float = 0.5 + + +class Map(Module[MapConfig]): + default_config = MapConfig + lidar: In[PointCloud2] global_map: Out[PointCloud2] global_costmap: Out[OccupancyGrid] _point_cloud_accumulator: PointCloudAccumulator - _global_config: GlobalConfig _preloaded_occupancy: OccupancyGrid | None = None - def __init__( # type: ignore[no-untyped-def] - self, - voxel_size: float = 0.05, - cost_resolution: float = 0.05, - global_publish_interval: float | None = None, - min_height: float = 0.10, - max_height: float = 0.5, - cfg: GlobalConfig = global_config, - **kwargs, - ) -> None: - self.voxel_size = voxel_size - self.cost_resolution = cost_resolution - self.global_publish_interval = global_publish_interval - self.min_height = min_height - self.max_height = max_height - self._global_config = cfg - self._point_cloud_accumulator = GeneralPointCloudAccumulator( - self.voxel_size, self._global_config - ) - - if self._global_config.simulation: - self.min_height = 0.3 - + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) + self.voxel_size = self.config.voxel_size + self.cost_resolution = self.config.cost_resolution + self.global_publish_interval = self.config.global_publish_interval + self.min_height = self.config.min_height + self.max_height = self.config.max_height + self._point_cloud_accumulator = GeneralPointCloudAccumulator(self.voxel_size, self.config.g) + + if self.config.g.simulation: + self.min_height = 0.3 @rpc def start(self) -> None: @@ -108,9 +103,9 @@ def _publish(self, _: Any) -> None: ) # When debugging occupancy navigation, load a predefined occupancy grid. - if self._global_config.mujoco_global_costmap_from_occupancy: + if self.config.g.mujoco_global_costmap_from_occupancy: if self._preloaded_occupancy is None: - path = Path(self._global_config.mujoco_global_costmap_from_occupancy) + path = Path(self.config.g.mujoco_global_costmap_from_occupancy) self._preloaded_occupancy = OccupancyGrid.from_path(path) occupancygrid = self._preloaded_occupancy diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py index 831ea6ee34..20a55f1d02 100644 --- a/dimos/simulation/manipulators/sim_module.py +++ b/dimos/simulation/manipulators/sim_module.py @@ -15,7 +15,6 @@ """Simulator-agnostic manipulator simulation module.""" from collections.abc import Callable -from dataclasses import dataclass from pathlib import Path import threading import time @@ -31,7 +30,6 @@ from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface -@dataclass(kw_only=True) class SimulationModuleConfig(ModuleConfig): engine: EngineType config_path: Path | Callable[[], Path] @@ -42,7 +40,6 @@ class SimulationModule(Module[SimulationModuleConfig]): """Module wrapper for manipulator simulation across engines.""" default_config = SimulationModuleConfig - config: SimulationModuleConfig joint_state: Out[JointState] robot_state: Out[RobotState] @@ -51,8 +48,8 @@ class SimulationModule(Module[SimulationModuleConfig]): MIN_CONTROL_RATE = 1.0 - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._backend: SimManipInterface | None = None self._control_rate = 100.0 self._monitor_rate = 100.0 diff --git a/dimos/simulation/manipulators/test_sim_module.py b/dimos/simulation/manipulators/test_sim_module.py index 334e2ce85f..72408fefed 100644 --- a/dimos/simulation/manipulators/test_sim_module.py +++ b/dimos/simulation/manipulators/test_sim_module.py @@ -17,10 +17,11 @@ import pytest +from dimos.protocol.rpc import RPCSpec from dimos.simulation.manipulators.sim_module import SimulationModule -class _DummyRPC: +class _DummyRPC(RPCSpec): def serve_module_rpc(self, _module) -> None: # type: ignore[no-untyped-def] return None diff --git a/dimos/teleop/keyboard/keyboard_teleop_module.py b/dimos/teleop/keyboard/keyboard_teleop_module.py index cc3c301804..854c0fbc22 100644 --- a/dimos/teleop/keyboard/keyboard_teleop_module.py +++ b/dimos/teleop/keyboard/keyboard_teleop_module.py @@ -28,7 +28,6 @@ ESC: Quit """ -from dataclasses import dataclass import os import threading import time @@ -64,7 +63,6 @@ def _clamp(value: float, min_val: float, max_val: float) -> float: return max(min_val, min(max_val, value)) -@dataclass class KeyboardTeleopConfig(ModuleConfig): model_path: str = "" ee_joint_id: int = 6 @@ -84,8 +82,8 @@ class KeyboardTeleopModule(Module[KeyboardTeleopConfig]): _stop_event: threading.Event _thread: threading.Thread | None = None - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._stop_event = threading.Event() @rpc diff --git a/dimos/teleop/phone/phone_teleop_module.py b/dimos/teleop/phone/phone_teleop_module.py index 4d40b995f3..f13842811b 100644 --- a/dimos/teleop/phone/phone_teleop_module.py +++ b/dimos/teleop/phone/phone_teleop_module.py @@ -22,7 +22,6 @@ velocity commands via configurable gains, and publishes. """ -from dataclasses import dataclass from pathlib import Path import threading import time @@ -48,7 +47,6 @@ STATIC_DIR = Path(__file__).parent / "web" / "static" -@dataclass class PhoneTeleopConfig(ModuleConfig): control_loop_hz: float = 50.0 linear_gain: float = 1.0 / 30.0 # Gain: maps degrees of tilt to m/s. 30 deg -> 1.0 m/s @@ -75,8 +73,8 @@ class PhoneTeleopModule(Module[PhoneTeleopConfig]): # Initialization # ------------------------------------------------------------------------- - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._is_engaged: bool = False self._teleop_button: bool = False diff --git a/dimos/teleop/quest/quest_extensions.py b/dimos/teleop/quest/quest_extensions.py index 68ec279efb..c92ac55a43 100644 --- a/dimos/teleop/quest/quest_extensions.py +++ b/dimos/teleop/quest/quest_extensions.py @@ -20,9 +20,10 @@ - VisualizingTeleopModule: Adds Rerun visualization (inherits press-and-hold engage) """ -from dataclasses import dataclass, field from typing import Any +from pydantic import Field + from dimos.core.stream import Out from dimos.msgs.geometry_msgs import PoseStamped, TwistStamped from dimos.teleop.quest.quest_teleop_module import Hand, QuestTeleopConfig, QuestTeleopModule @@ -33,7 +34,6 @@ ) -@dataclass class TwistTeleopConfig(QuestTeleopConfig): """Configuration for TwistTeleopModule.""" @@ -42,7 +42,7 @@ class TwistTeleopConfig(QuestTeleopConfig): # Example implementation to show how to extend QuestTeleopModule for different teleop behaviors and outputs. -class TwistTeleopModule(QuestTeleopModule): +class TwistTeleopModule(QuestTeleopModule[TwistTeleopConfig]): """Quest teleop that outputs TwistStamped instead of PoseStamped. Config: @@ -56,7 +56,6 @@ class TwistTeleopModule(QuestTeleopModule): """ default_config = TwistTeleopConfig - config: TwistTeleopConfig left_twist: Out[TwistStamped] right_twist: Out[TwistStamped] @@ -75,7 +74,6 @@ def _publish_msg(self, hand: Hand, output_msg: PoseStamped) -> None: self.right_twist.publish(twist) -@dataclass class ArmTeleopConfig(QuestTeleopConfig): """Configuration for ArmTeleopModule. @@ -85,10 +83,10 @@ class ArmTeleopConfig(QuestTeleopConfig): hand's commands to the correct TeleopIKTask. """ - task_names: dict[str, str] = field(default_factory=dict) + task_names: dict[str, str] = Field(default_factory=dict) -class ArmTeleopModule(QuestTeleopModule): +class ArmTeleopModule(QuestTeleopModule[ArmTeleopConfig]): """Quest teleop with per-hand press-and-hold engage and task name routing. Each controller's primary button (X for left, A for right) @@ -105,10 +103,9 @@ class ArmTeleopModule(QuestTeleopModule): """ default_config = ArmTeleopConfig - config: ArmTeleopConfig - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._task_names: dict[Hand, str] = { Hand[k.upper()]: v for k, v in self.config.task_names.items() diff --git a/dimos/teleop/quest/quest_teleop_module.py b/dimos/teleop/quest/quest_teleop_module.py index f862558424..9beaf0da3e 100644 --- a/dimos/teleop/quest/quest_teleop_module.py +++ b/dimos/teleop/quest/quest_teleop_module.py @@ -26,7 +26,7 @@ from pathlib import Path import threading import time -from typing import Any +from typing import Any, TypeVar from dimos_lcm.geometry_msgs import PoseStamped as LCMPoseStamped from dimos_lcm.sensor_msgs import Joy as LCMJoy @@ -68,7 +68,6 @@ class QuestTeleopStatus: buttons: Buttons -@dataclass class QuestTeleopConfig(ModuleConfig): """Configuration for Quest Teleoperation Module.""" @@ -76,7 +75,10 @@ class QuestTeleopConfig(ModuleConfig): server_port: int = 8443 -class QuestTeleopModule(Module[QuestTeleopConfig]): +_Config = TypeVar("_Config", bound=QuestTeleopConfig) + + +class QuestTeleopModule(Module[_Config]): """Quest Teleoperation Module for Meta Quest controllers. Receives controller data from the Quest web app via an embedded WebSocket @@ -89,7 +91,7 @@ class QuestTeleopModule(Module[QuestTeleopConfig]): - buttons: Buttons (button states for both controllers) """ - default_config = QuestTeleopConfig + default_config = QuestTeleopConfig # type: ignore[assignment] # Outputs: delta poses for each controller left_controller_output: Out[PoseStamped] @@ -100,8 +102,8 @@ class QuestTeleopModule(Module[QuestTeleopConfig]): # Initialization # ------------------------------------------------------------------------- - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # Engage state (per-hand) self._is_engaged: dict[Hand, bool] = {Hand.LEFT: False, Hand.RIGHT: False} diff --git a/dimos/utils/cli/lcmspy/lcmspy.py b/dimos/utils/cli/lcmspy/lcmspy.py index 651e8d551b..5b2d0be4ef 100755 --- a/dimos/utils/cli/lcmspy/lcmspy.py +++ b/dimos/utils/cli/lcmspy/lcmspy.py @@ -13,9 +13,9 @@ # limitations under the License. from collections import deque -from dataclasses import dataclass import threading import time +from typing import Any from dimos.protocol.service.lcmservice import LCMConfig, LCMService from dimos.utils.human import human_bytes @@ -98,20 +98,19 @@ def __str__(self) -> str: return f"topic({self.name})" -@dataclass class LCMSpyConfig(LCMConfig): topic_history_window: float = 60.0 -class LCMSpy(LCMService, Topic): +class LCMSpy(LCMService[LCMSpyConfig], Topic): default_config = LCMSpyConfig topic = dict[str, Topic] graph_log_window: float = 1.0 topic_class: type[Topic] = Topic - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - Topic.__init__(self, name="total", history_window=self.config.topic_history_window) # type: ignore[attr-defined] + Topic.__init__(self, name="total", history_window=self.config.topic_history_window) self.topic = {} # type: ignore[assignment] self._topic_lock = threading.Lock() @@ -150,7 +149,6 @@ def update_graphs(self, step_window: float = 1.0) -> None: self.bandwidth_history.append(kbps) -@dataclass class GraphLCMSpyConfig(LCMSpyConfig): graph_log_window: float = 1.0 @@ -162,9 +160,9 @@ class GraphLCMSpy(LCMSpy, GraphTopic): graph_log_stop_event: threading.Event = threading.Event() topic_class: type[Topic] = GraphTopic - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - GraphTopic.__init__(self, name="total", history_window=self.config.topic_history_window) # type: ignore[attr-defined] + GraphTopic.__init__(self, name="total", history_window=self.config.topic_history_window) def start(self) -> None: super().start() diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 9bba9dd82f..6729f143cd 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -16,11 +16,11 @@ from __future__ import annotations -from dataclasses import dataclass, field +from collections.abc import Callable +from dataclasses import field from functools import lru_cache import time from typing import ( - TYPE_CHECKING, Any, Literal, Protocol, @@ -31,6 +31,8 @@ ) from reactivex.disposable import Disposable +from rerun._baseclasses import Archetype +from rerun.blueprint import Blueprint from toolz import pipe # type: ignore[import-untyped] import typer @@ -39,6 +41,7 @@ from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.pubsub.patterns import Glob, pattern_matches +from dimos.protocol.pubsub.spec import SubscribeAllCapable from dimos.utils.logging_config import setup_logger # Message types with large payloads that need rate-limiting. @@ -96,15 +99,7 @@ logger = setup_logger() -if TYPE_CHECKING: - from collections.abc import Callable - - from rerun._baseclasses import Archetype - from rerun.blueprint import Blueprint - - from dimos.protocol.pubsub.spec import SubscribeAllCapable - -BlueprintFactory: TypeAlias = "Callable[[], Blueprint]" +BlueprintFactory: TypeAlias = Callable[[], "Blueprint"] # to_rerun() can return a single archetype or a list of (entity_path, archetype) tuples RerunMulti: TypeAlias = "list[tuple[str, Archetype]]" @@ -113,8 +108,6 @@ def is_rerun_multi(data: Any) -> TypeGuard[RerunMulti]: """Check if data is a list of (entity_path, archetype) tuples.""" - from rerun._baseclasses import Archetype - return ( isinstance(data, list) and bool(data) @@ -167,7 +160,6 @@ def _resolve_viewer_mode() -> ViewerMode: return _BACKEND_TO_MODE.get(global_config.viewer, "native") -@dataclass class Config(ModuleConfig): """Configuration for RerunBridgeModule.""" @@ -190,7 +182,7 @@ class Config(ModuleConfig): blueprint: BlueprintFactory | None = _default_blueprint -class RerunBridgeModule(Module): +class RerunBridgeModule(Module[Config]): """Bridge that logs messages from pubsubs to Rerun. Spawns its own Rerun viewer and subscribes to all topics on each provided @@ -207,7 +199,6 @@ class RerunBridgeModule(Module): """ default_config = Config - config: Config @lru_cache(maxsize=256) def _visual_override_for_entity_path( @@ -218,8 +209,6 @@ def _visual_override_for_entity_path( Chains matching overrides from config, ending with final_convert which handles .to_rerun() or passes through Archetypes. """ - from rerun._baseclasses import Archetype - # find all matching converters for this entity path matches = [ fn diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 0304e3b77b..7a5c9587e1 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -46,8 +46,7 @@ ) from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.mapping.occupancy.gradient import gradient from dimos.mapping.occupancy.inflation import simple_inflate @@ -64,7 +63,11 @@ _browser_opened = False -class WebsocketVisModule(Module): +class WebsocketConfig(ModuleConfig): + port: int = 7779 + + +class WebsocketVisModule(Module[WebsocketConfig]): """ WebSocket-based visualization module for real-time navigation data. @@ -83,6 +86,8 @@ class WebsocketVisModule(Module): - click_goal: Goal position from user clicks """ + default_config = WebsocketConfig + # LCM inputs odom: In[PoseStamped] gps_location: In[LatLon] @@ -97,12 +102,7 @@ class WebsocketVisModule(Module): cmd_vel: Out[Twist] movecmd_stamped: Out[TwistStamped] - def __init__( - self, - port: int = 7779, - cfg: GlobalConfig = global_config, - **kwargs: Any, - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize the WebSocket visualization module. Args: @@ -110,9 +110,6 @@ def __init__( cfg: Optional global config for viewer settings """ super().__init__(**kwargs) - self._global_config = cfg - - self.port = port self._uvicorn_server_thread: threading.Thread | None = None self.sio: socketio.AsyncServer | None = None self.app = None @@ -127,7 +124,7 @@ def __init__( # Track GPS goal points for visualization self.gps_goal_points: list[dict[str, float]] = [] logger.info( - f"WebSocket visualization module initialized on port {port}, GPS goal tracking enabled" + f"WebSocket visualization module initialized on port {self.config.port}, GPS goal tracking enabled" ) def _start_broadcast_loop(self) -> None: @@ -157,8 +154,8 @@ def start(self) -> None: # Auto-open browser only for rerun-web (dashboard with Rerun iframe + command center) # For rerun and foxglove, users access the command center manually if needed - if self._global_config.viewer == "rerun-web": - url = f"http://localhost:{self.port}/" + if self.config.g.viewer == "rerun-web": + url = f"http://localhost:{self.config.port}/" logger.info(f"Dimensional Command Center: {url}") global _browser_opened @@ -234,7 +231,7 @@ def _create_server(self) -> None: async def serve_index(request): # type: ignore[no-untyped-def] """Serve appropriate HTML based on viewer mode.""" # If running native Rerun, redirect to standalone command center - if self._global_config.viewer != "rerun-web": + if self.config.g.viewer != "rerun-web": return RedirectResponse(url="/command-center") # Otherwise serve full dashboard with Rerun iframe @@ -355,7 +352,7 @@ def _run_uvicorn_server(self) -> None: config = uvicorn.Config( self.app, # type: ignore[arg-type] host="0.0.0.0", - port=self.port, + port=self.config.port, log_level="error", # Reduce verbosity ) self._uvicorn_server = uvicorn.Server(config) diff --git a/docs/usage/blueprints.md b/docs/usage/blueprints.md index ed48670cb4..80a6b24b19 100644 --- a/docs/usage/blueprints.md +++ b/docs/usage/blueprints.md @@ -9,13 +9,16 @@ You create a `Blueprint` from a single module (say `ConnectionModule`) with: ```python session=blueprint-ex1 from dimos.core.blueprints import Blueprint from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig -class ConnectionModule(Module): - def __init__(self, arg1, arg2, kwarg='value') -> None: - super().__init__() +class ConnectionConfig(ModuleConfig): + arg1: int + arg2: str = "value" + +class ConnectionModule(Module[ConnectionConfig]): + default_config = ConnectionConfig -blueprint = Blueprint.create(ConnectionModule, 'arg1', 'arg2', kwarg='value') +blueprint = Blueprint.create(ConnectionModule, arg1=5, arg2="foo") ``` But the same thing can be accomplished more succinctly as: @@ -37,9 +40,11 @@ You can link multiple blueprints together with `autoconnect`: ```python session=blueprint-ex1 from dimos.core.blueprints import autoconnect -class Module1(Module): - def __init__(self, arg1) -> None: - super().__init__() +class Config(ModuleConfig): + arg1: int = 42 + +class Module1(Module[Config]): + default_config = Config class Module2(Module): ... @@ -206,7 +211,7 @@ blueprint.remappings([ ## Overriding global configuration. -Each module can optionally take global config as a `cfg` option in `__init__`. E.g.: +Each module includes the global config available as `self.config.g`. E.g.: ```python session=blueprint-ex3 from dimos.core.core import rpc @@ -214,9 +219,8 @@ from dimos.core.module import Module from dimos.core.global_config import GlobalConfig class ModuleA(Module): - - def __init__(self, cfg: GlobalConfig | None = None): - self._global_config: GlobalConfig = cfg + def some_method(self): + print(self.config.g.viewer) ... ``` diff --git a/docs/usage/configuration.md b/docs/usage/configuration.md index fe6e0029f0..384ef5240e 100644 --- a/docs/usage/configuration.md +++ b/docs/usage/configuration.md @@ -2,23 +2,19 @@ Dimos provides a `Configurable` base class. See [`service/spec.py`](/dimos/protocol/service/spec.py#L22). -This allows using dataclasses to specify configuration structure and default values per module. +This allows using pydantic models to specify configuration structure and default values per module. ```python from dimos.protocol.service import Configurable +from dimos.protocol.service.spec import BaseConfig from rich import print -from dataclasses import dataclass -@dataclass -class Config(): +class Config(BaseConfig): x: int = 3 hello: str = "world" -class MyClass(Configurable): +class MyClass(Configurable[Config]): default_config = Config - config: Config - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) myclass1 = MyClass() print(myclass1.config) @@ -48,22 +44,19 @@ Error: Config.__init__() got an unexpected keyword argument 'something' [Modules](/docs/usage/modules.md) inherit from `Configurable`, so all of the above applies. Module configs should inherit from `ModuleConfig` ([`core/module.py`](/dimos/core/module.py#L40)), which includes shared configuration for all modules like transport protocols, frame IDs, etc. ```python -from dataclasses import dataclass from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from rich import print -@dataclass class Config(ModuleConfig): frame_id: str = "world" publish_interval: float = 0 voxel_size: float = 0.05 device: str = "CUDA:0" -class MyModule(Module): +class MyModule(Module[Config]): default_config = Config - config: Config def __init__(self, **kwargs) -> None: super().__init__(**kwargs) diff --git a/docs/usage/native_modules.md b/docs/usage/native_modules.md index 929ac18424..de12417b4a 100644 --- a/docs/usage/native_modules.md +++ b/docs/usage/native_modules.md @@ -17,7 +17,6 @@ Python side native module is just a definition of a **config** dataclass and **m Both the config dataclass and pubsub topics get converted to CLI args passed down to your executable once the module is started. ```python no-result session=nativemodule -from dataclasses import dataclass from dimos.core.stream import Out from dimos.core.transport import LCMTransport from dimos.core.native_module import NativeModule, NativeModuleConfig @@ -25,13 +24,12 @@ from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.msgs.sensor_msgs.Imu import Imu import time -@dataclass(kw_only=True) class MyLidarConfig(NativeModuleConfig): executable: str = "./build/my_lidar" host_ip: str = "192.168.1.5" frequency: float = 10.0 -class MyLidar(NativeModule): +class MyLidar(NativeModule[MyLidarConfig]): default_config = MyLidarConfig pointcloud: Out[PointCloud2] imu: Out[Imu] @@ -98,18 +96,18 @@ When `stop()` is called, the process receives SIGTERM. If it doesn't exit within Any field you add to your config subclass automatically becomes a `--name value` CLI arg. Fields from `NativeModuleConfig` itself (like `executable`, `extra_args`, `cwd`) are **not** passed — they're for Python-side orchestration only. ```python skip +from pydantic import Field class LogFormat(enum.Enum): TEXT = "text" JSON = "json" -@dataclass(kw_only=True) class MyConfig(NativeModuleConfig): executable: str = "./build/my_module" # relative or absolute path to your executable host_ip: str = "192.168.1.5" # becomes --host_ip 192.168.1.5 frequency: float = 10.0 # becomes --frequency 10.0 enable_imu: bool = True # becomes --enable_imu true - filters: list[str] = field(default_factory=lambda: ["a", "b"]) # becomes --filters a,b + filters: list[str] = Field(default_factory=lambda: ["a", "b"]) # becomes --filters a,b ``` - `None` values are skipped. @@ -121,16 +119,11 @@ class MyConfig(NativeModuleConfig): If a config field shouldn't be a CLI arg, add it to `cli_exclude`: ```python skip -@dataclass(kw_only=True) class FastLio2Config(NativeModuleConfig): executable: str = "./build/fastlio2" config: str = "mid360.yaml" # human-friendly name - config_path: str | None = None # resolved absolute path + config_path: str = Field(default_factory=lambda m: str(Path(m["config"]).resolve())) cli_exclude: frozenset[str] = frozenset({"config"}) # only config_path is passed - - def __post_init__(self) -> None: - if self.config_path is None: - self.config_path = str(Path(self.config).resolve()) ``` ## Using with blueprints @@ -173,7 +166,6 @@ NativeModule pipes subprocess stdout and stderr through structlog: If your native binary outputs structured JSON lines, set `log_format=LogFormat.JSON`: ```python skip -@dataclass(kw_only=True) class MyConfig(NativeModuleConfig): executable: str = "./build/my_module" log_format: LogFormat = LogFormat.JSON @@ -236,7 +228,6 @@ from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.msgs.sensor_msgs.Imu import Imu from dimos.spec import perception -@dataclass(kw_only=True) class Mid360Config(NativeModuleConfig): cwd: str | None = "cpp" executable: str = "result/bin/mid360_native" @@ -248,7 +239,7 @@ class Mid360Config(NativeModuleConfig): frame_id: str = "lidar_link" # ... SDK port configuration -class Mid360(NativeModule, perception.Lidar, perception.IMU): +class Mid360(NativeModule[Mid360Config], perception.Lidar, perception.IMU): default_config = Mid360Config lidar: Out[PointCloud2] imu: Out[Imu] @@ -271,7 +262,6 @@ If `build_command` is set in the module config, and the executable doesn't exist Build output is piped through structlog (stdout at `info`, stderr at `warning`). ```python skip -@dataclass(kw_only=True) class MyLidarConfig(NativeModuleConfig): cwd: str | None = "cpp" executable: str = "result/bin/my_lidar" diff --git a/docs/usage/transforms.md b/docs/usage/transforms.md index 8b98e4e81d..8a3f708cd2 100644 --- a/docs/usage/transforms.md +++ b/docs/usage/transforms.md @@ -173,9 +173,7 @@ Modules in DimOS automatically get a `frame_id` property. This is controlled by ```python from dimos.core.module import Module, ModuleConfig -from dataclasses import dataclass -@dataclass class MyModuleConfig(ModuleConfig): frame_id: str = "sensor_link" frame_id_prefix: str | None = None @@ -228,8 +226,6 @@ from dimos.core.module_coordinator import ModuleCoordinator class RobotBaseModule(Module): """Publishes the robot's position in the world frame at 10Hz.""" - def __init__(self, **kwargs: object) -> None: - super().__init__(**kwargs) @rpc def start(self) -> None: diff --git a/examples/simplerobot/simplerobot.py b/examples/simplerobot/simplerobot.py index 010b3bf2eb..2a1867b37c 100644 --- a/examples/simplerobot/simplerobot.py +++ b/examples/simplerobot/simplerobot.py @@ -22,10 +22,8 @@ Subscribes to Twist commands and publishes PoseStamped. """ -from dataclasses import dataclass import math import time -from typing import Any import reactivex as rx @@ -48,7 +46,6 @@ def apply_twist(pose: Pose, twist: Twist, dt: float) -> Pose: ) -@dataclass class SimpleRobotConfig(ModuleConfig): frame_id: str = "world" update_rate: float = 30.0 @@ -61,12 +58,9 @@ class SimpleRobot(Module[SimpleRobotConfig]): cmd_vel: In[Twist] pose: Out[PoseStamped] default_config = SimpleRobotConfig - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._pose = Pose() - self._vel = Twist() - self._vel_time = 0.0 + _pose = Pose() + _vel = Twist() + _vel_time = 0.0 @rpc def start(self) -> None: diff --git a/pyproject.toml b/pyproject.toml index 017562a78a..4370944b27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -354,8 +354,12 @@ exclude = [ [tool.ruff.lint] extend-select = ["E", "W", "F", "B", "UP", "N", "I", "C90", "A", "RUF", "TCH"] -# TODO: All of these should be fixed, but it's easier commit autofixes first -ignore = ["A001", "A002", "B008", "B017", "B019", "B024", "B026", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N817", "N999", "RUF003", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "UP007"] +ignore = [ + # TODO: All of these should be fixed, but it's easier commit autofixes first + "A001", "A002", "B008", "B017", "B019", "B024", "B026", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N817", "N999", "RUF003", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "UP007", + # This breaks runtime type checking (both for us, and users introspecting our APIs) + "TC001", "TC002", "TC003" +] [tool.ruff.lint.per-file-ignores] "dimos/models/Detic/*" = ["ALL"] From 8ff7377f40f3ddbaba1269accece39dabe6ca630 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Sat, 14 Mar 2026 06:05:13 +0200 Subject: [PATCH 08/42] chore(comments): remove section markers (#1546) --- dimos/agents/mcp/mcp_adapter.py | 20 --- dimos/agents/mcp/mcp_server.py | 19 --- dimos/agents_deprecated/agent.py | 16 -- .../agents_deprecated/prompt_builder/impl.py | 2 - dimos/control/blueprints.py | 37 ----- dimos/control/coordinator.py | 34 ----- dimos/control/task.py | 9 -- dimos/control/tasks/cartesian_ik_task.py | 4 - dimos/control/tasks/servo_task.py | 4 - dimos/control/tasks/teleop_task.py | 4 - dimos/control/tasks/trajectory_task.py | 4 - dimos/control/tasks/velocity_task.py | 4 - dimos/control/test_control.py | 39 ----- dimos/control/tick_loop.py | 7 - dimos/core/daemon.py | 14 -- dimos/core/test_cli_stop_status.py | 10 -- dimos/core/test_daemon.py | 21 --- dimos/core/test_e2e_daemon.py | 19 --- dimos/core/test_mcp_integration.py | 15 -- dimos/core/tests/demo_devex.py | 33 ---- .../hardware/drive_trains/flowbase/adapter.py | 24 --- dimos/hardware/drive_trains/mock/adapter.py | 24 --- dimos/hardware/drive_trains/spec.py | 10 -- dimos/hardware/manipulators/mock/adapter.py | 40 ----- dimos/hardware/manipulators/piper/adapter.py | 36 ----- dimos/hardware/manipulators/spec.py | 26 ---- dimos/hardware/manipulators/xarm/adapter.py | 36 ----- dimos/manipulation/blueprints.py | 19 --- .../control/coordinator_client.py | 17 --- .../cartesian_motion_controller.py | 12 -- .../joint_trajectory_controller.py | 12 -- dimos/manipulation/manipulation_interface.py | 4 - dimos/manipulation/manipulation_module.py | 28 ---- dimos/manipulation/pick_and_place_module.py | 28 ---- .../planning/kinematics/jacobian_ik.py | 2 +- .../planning/kinematics/pinocchio_ik.py | 28 ---- .../planning/monitor/world_monitor.py | 22 +-- .../monitor/world_obstacle_monitor.py | 2 +- .../planning/planners/rrt_planner.py | 2 +- dimos/manipulation/planning/spec/types.py | 11 -- .../planning/world/drake_world.py | 22 ++- dimos/manipulation/test_manipulation_unit.py | 29 ---- dimos/memory/timeseries/base.py | 2 - dimos/memory/timeseries/legacy.py | 2 - .../temporal_memory/entity_graph_db.py | 14 +- .../frame_window_accumulator.py | 12 -- .../temporal_memory/temporal_memory.py | 28 ---- .../temporal_memory/temporal_state.py | 8 - .../test_temporal_memory_module.py | 45 ------ .../temporal_memory/window_analyzer.py | 16 -- dimos/protocol/pubsub/impl/shmpubsub.py | 11 +- dimos/protocol/pubsub/shm/ipc_factory.py | 20 --- .../service/system_configurator/base.py | 6 +- .../service/system_configurator/lcm.py | 6 +- dimos/protocol/service/test_lcmservice.py | 8 +- .../service/test_system_configurator.py | 16 +- dimos/skills/skills.py | 21 --- dimos/stream/frame_processor.py | 2 - dimos/teleop/phone/phone_teleop_module.py | 32 ---- dimos/teleop/quest/blueprints.py | 8 - dimos/teleop/quest/quest_teleop_module.py | 28 ---- dimos/test_no_sections.py | 143 ++++++++++++++++++ dimos/utils/cli/dtop.py | 23 --- dimos/utils/simple_controller.py | 6 - dimos/utils/test_data.py | 5 - docker/navigation/.env.hardware | 36 ----- docker/navigation/Dockerfile | 19 --- docker/navigation/docker-compose.dev.yml | 7 - .../manipulation/adding_a_custom_arm.md | 34 ----- 69 files changed, 196 insertions(+), 1111 deletions(-) create mode 100644 dimos/test_no_sections.py diff --git a/dimos/agents/mcp/mcp_adapter.py b/dimos/agents/mcp/mcp_adapter.py index 9b8cc5c4b9..213bf71e23 100644 --- a/dimos/agents/mcp/mcp_adapter.py +++ b/dimos/agents/mcp/mcp_adapter.py @@ -63,10 +63,6 @@ def __init__(self, url: str | None = None, timeout: int = DEFAULT_TIMEOUT) -> No self.url = url self.timeout = timeout - # ------------------------------------------------------------------ - # Low-level JSON-RPC - # ------------------------------------------------------------------ - def call(self, method: str, params: dict[str, Any] | None = None) -> dict[str, Any]: """Send a JSON-RPC request and return the parsed response. @@ -87,10 +83,6 @@ def call(self, method: str, params: dict[str, Any] | None = None) -> dict[str, A raise McpError(f"HTTP {resp.status_code}: {e}") from e return resp.json() # type: ignore[no-any-return] - # ------------------------------------------------------------------ - # MCP standard methods - # ------------------------------------------------------------------ - def initialize(self) -> dict[str, Any]: """Send ``initialize`` and return server info.""" return self.call("initialize") @@ -112,10 +104,6 @@ def call_tool_text(self, name: str, arguments: dict[str, Any] | None = None) -> return "" return content[0].get("text", str(content[0])) # type: ignore[no-any-return] - # ------------------------------------------------------------------ - # Readiness probes - # ------------------------------------------------------------------ - def wait_for_ready(self, timeout: float = 10.0, interval: float = 0.5) -> bool: """Poll until the MCP server responds, or return False on timeout.""" deadline = time.monotonic() + timeout @@ -148,10 +136,6 @@ def wait_for_down(self, timeout: float = 10.0, interval: float = 0.5) -> bool: time.sleep(interval) return False - # ------------------------------------------------------------------ - # Class methods for discovery - # ------------------------------------------------------------------ - @classmethod def from_run_entry(cls, entry: Any | None = None, timeout: int = DEFAULT_TIMEOUT) -> McpAdapter: """Create an adapter from a RunEntry, or discover the latest one. @@ -173,10 +157,6 @@ def from_run_entry(cls, entry: Any | None = None, timeout: int = DEFAULT_TIMEOUT url = f"http://localhost:{global_config.mcp_port}/mcp" return cls(url=url, timeout=timeout) - # ------------------------------------------------------------------ - # Internals - # ------------------------------------------------------------------ - @staticmethod def _unwrap(response: dict[str, Any]) -> dict[str, Any]: """Extract the ``result`` from a JSON-RPC response, raising on error.""" diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index e5697542fb..9149de06ec 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -50,11 +50,6 @@ app.state.rpc_calls = {} -# --------------------------------------------------------------------------- -# JSON-RPC helpers -# --------------------------------------------------------------------------- - - def _jsonrpc_result(req_id: Any, result: Any) -> dict[str, Any]: return {"jsonrpc": "2.0", "id": req_id, "result": result} @@ -67,11 +62,6 @@ def _jsonrpc_error(req_id: Any, code: int, message: str) -> dict[str, Any]: return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}} -# --------------------------------------------------------------------------- -# JSON-RPC handlers (standard MCP protocol only) -# --------------------------------------------------------------------------- - - def _handle_initialize(req_id: Any) -> dict[str, Any]: return _jsonrpc_result( req_id, @@ -177,11 +167,6 @@ async def mcp_endpoint(request: Request) -> Response: return JSONResponse(result) -# --------------------------------------------------------------------------- -# McpServer Module -# --------------------------------------------------------------------------- - - class McpServer(Module): _uvicorn_server: uvicorn.Server | None = None _serve_future: concurrent.futures.Future[None] | None = None @@ -215,10 +200,6 @@ def on_system_modules(self, modules: list[RPCClient]) -> None: for skill_info in app.state.skills } - # ------------------------------------------------------------------ - # Introspection skills (exposed as MCP tools via tools/list) - # ------------------------------------------------------------------ - @skill def server_status(self) -> str: """Get MCP server status: main process PID, deployed modules, and skill count.""" diff --git a/dimos/agents_deprecated/agent.py b/dimos/agents_deprecated/agent.py index 0443b2cc94..1d48ce2fa4 100644 --- a/dimos/agents_deprecated/agent.py +++ b/dimos/agents_deprecated/agent.py @@ -68,9 +68,6 @@ _MAX_SAVED_FRAMES = 100 # Maximum number of frames to save -# ----------------------------------------------------------------------------- -# region Agent Base Class -# ----------------------------------------------------------------------------- class Agent: """Base agent that manages memory and subscriptions.""" @@ -105,12 +102,6 @@ def dispose_all(self) -> None: logger.info("No disposables to dispose.") -# endregion Agent Base Class - - -# ----------------------------------------------------------------------------- -# region LLMAgent Base Class (Generic LLM Agent) -# ----------------------------------------------------------------------------- class LLMAgent(Agent): """Generic LLM agent containing common logic for LLM-based agents. @@ -689,12 +680,6 @@ def dispose_all(self) -> None: self.response_subject.on_completed() -# endregion LLMAgent Base Class (Generic LLM Agent) - - -# ----------------------------------------------------------------------------- -# region OpenAIAgent Subclass (OpenAI-Specific Implementation) -# ----------------------------------------------------------------------------- class OpenAIAgent(LLMAgent): """OpenAI agent implementation that uses OpenAI's API for processing. @@ -914,4 +899,3 @@ def stream_query(self, query_text: str) -> Observable: # type: ignore[type-arg] ) -# endregion OpenAIAgent Subclass (OpenAI-Specific Implementation) diff --git a/dimos/agents_deprecated/prompt_builder/impl.py b/dimos/agents_deprecated/prompt_builder/impl.py index 35c864062a..354057464f 100644 --- a/dimos/agents_deprecated/prompt_builder/impl.py +++ b/dimos/agents_deprecated/prompt_builder/impl.py @@ -148,7 +148,6 @@ def build( # type: ignore[no-untyped-def] # print("system_prompt: ", system_prompt) # print("rag_context: ", rag_context) - # region Token Counts if not override_token_limit: rag_token_cnt = self.tokenizer.token_count(rag_context) system_prompt_token_cnt = self.tokenizer.token_count(system_prompt) @@ -163,7 +162,6 @@ def build( # type: ignore[no-untyped-def] system_prompt_token_cnt = 0 user_query_token_cnt = 0 image_token_cnt = 0 - # endregion Token Counts # Create a component dictionary for dynamic allocation components = { diff --git a/dimos/control/blueprints.py b/dimos/control/blueprints.py index 0384c69160..7c6036b20c 100644 --- a/dimos/control/blueprints.py +++ b/dimos/control/blueprints.py @@ -49,10 +49,6 @@ _XARM7_MODEL_PATH = LfsPath("xarm_description/urdf/xarm7/xarm7.urdf") -# ============================================================================= -# Single Arm Blueprints -# ============================================================================= - # Mock 7-DOF arm (for testing) coordinator_mock = control_coordinator( tick_rate=100.0, @@ -168,10 +164,6 @@ ) -# ============================================================================= -# Dual Arm Blueprints -# ============================================================================= - # Dual mock arms (7-DOF left, 6-DOF right) coordinator_dual_mock = control_coordinator( tick_rate=100.0, @@ -298,10 +290,6 @@ ) -# ============================================================================= -# Streaming Control Blueprints -# ============================================================================= - # XArm6 teleop - streaming position control coordinator_teleop_xarm6 = control_coordinator( tick_rate=100.0, @@ -399,11 +387,6 @@ ) -# ============================================================================= -# Cartesian IK Blueprints (internal Pinocchio IK solver) -# ============================================================================= - - # Mock 6-DOF arm with CartesianIK coordinator_cartesian_ik_mock = control_coordinator( tick_rate=100.0, @@ -471,10 +454,6 @@ ) -# ============================================================================= -# Teleop IK Blueprints (VR teleoperation with internal Pinocchio IK) -# ============================================================================= - # Single XArm7 with TeleopIK coordinator_teleop_xarm7 = control_coordinator( tick_rate=100.0, @@ -605,10 +584,6 @@ ) -# ============================================================================= -# Twist Base Blueprints (velocity-commanded platforms) -# ============================================================================= - # Mock holonomic twist base (3-DOF: vx, vy, wz) _base_joints = make_twist_base_joints("base") coordinator_mock_twist_base = control_coordinator( @@ -636,10 +611,6 @@ ) -# ============================================================================= -# Mobile Manipulation Blueprints (arm + twist base) -# ============================================================================= - # Mock arm (7-DOF) + mock holonomic base (3-DOF) _mm_base_joints = make_twist_base_joints("base") coordinator_mobile_manip_mock = control_coordinator( @@ -679,10 +650,6 @@ ) -# ============================================================================= -# Raw Blueprints (for programmatic setup) -# ============================================================================= - coordinator_basic = control_coordinator( tick_rate=100.0, publish_joint_state=True, @@ -694,10 +661,6 @@ ) -# ============================================================================= -# Exports -# ============================================================================= - __all__ = [ # Raw "coordinator_basic", diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index 73e036e873..16f4e53f46 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -68,11 +68,6 @@ logger = setup_logger() -# ============================================================================= -# Configuration -# ============================================================================= - - @dataclass class TaskConfig: """Configuration for a control task. @@ -124,11 +119,6 @@ class ControlCoordinatorConfig(ModuleConfig): tasks: list[TaskConfig] = field(default_factory=lambda: []) -# ============================================================================= -# ControlCoordinator Module -# ============================================================================= - - class ControlCoordinator(Module[ControlCoordinatorConfig]): """Centralized control coordinator with per-joint arbitration. @@ -201,10 +191,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: logger.info(f"ControlCoordinator initialized at {self.config.tick_rate}Hz") - # ========================================================================= - # Config-based Setup - # ========================================================================= - def _setup_from_config(self) -> None: """Create hardware and tasks from config (called on start).""" hardware_added: list[str] = [] @@ -343,10 +329,6 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: else: raise ValueError(f"Unknown task type: {task_type}") - # ========================================================================= - # Hardware Management (RPC) - # ========================================================================= - @rpc def add_hardware( self, @@ -446,10 +428,6 @@ def get_joint_positions(self) -> dict[str, float]: positions[joint_name] = joint_state.position return positions - # ========================================================================= - # Task Management (RPC) - # ========================================================================= - @rpc def add_task(self, task: ControlTask) -> bool: """Register a task with the coordinator.""" @@ -492,10 +470,6 @@ def get_active_tasks(self) -> list[str]: with self._task_lock: return [name for name, task in self._tasks.items() if task.is_active()] - # ========================================================================= - # Streaming Control - # ========================================================================= - def _on_joint_command(self, msg: JointState) -> None: """Route incoming JointState to streaming tasks by joint name. @@ -603,10 +577,6 @@ def task_invoke( return getattr(task, method)(**kwargs) - # ========================================================================= - # Gripper - # ========================================================================= - @rpc def set_gripper_position(self, hardware_id: str, position: float) -> bool: """Set gripper position on a specific hardware device. @@ -640,10 +610,6 @@ def get_gripper_position(self, hardware_id: str) -> float | None: return None return hw.adapter.read_gripper_position() - # ========================================================================= - # Lifecycle - # ========================================================================= - @rpc def start(self) -> None: """Start the coordinator control loop.""" diff --git a/dimos/control/task.py b/dimos/control/task.py index ecdf9ab7f4..c9ef03fbf0 100644 --- a/dimos/control/task.py +++ b/dimos/control/task.py @@ -37,10 +37,6 @@ from dimos.msgs.geometry_msgs import Pose, PoseStamped from dimos.teleop.quest.quest_types import Buttons -# ============================================================================= -# Data Types -# ============================================================================= - @dataclass(frozen=True) class ResourceClaim: @@ -168,11 +164,6 @@ def get_values(self) -> list[float] | None: return None -# ============================================================================= -# ControlTask Protocol -# ============================================================================= - - @runtime_checkable class ControlTask(Protocol): """Protocol for passive tasks that run within the coordinator. diff --git a/dimos/control/tasks/cartesian_ik_task.py b/dimos/control/tasks/cartesian_ik_task.py index 6ea5ddc55b..67d4e4ed52 100644 --- a/dimos/control/tasks/cartesian_ik_task.py +++ b/dimos/control/tasks/cartesian_ik_task.py @@ -255,10 +255,6 @@ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: f"CartesianIKTask {self._name} preempted by {by_task} on joints {joints}" ) - # ========================================================================= - # Task-specific methods - # ========================================================================= - def on_cartesian_command(self, pose: Pose | PoseStamped, t_now: float) -> bool: """Handle incoming cartesian command (target EE pose). diff --git a/dimos/control/tasks/servo_task.py b/dimos/control/tasks/servo_task.py index b69b4dd099..50805bfa2c 100644 --- a/dimos/control/tasks/servo_task.py +++ b/dimos/control/tasks/servo_task.py @@ -159,10 +159,6 @@ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: if joints & self._joint_names: logger.warning(f"JointServoTask {self._name} preempted by {by_task} on joints {joints}") - # ========================================================================= - # Task-specific methods - # ========================================================================= - def set_target(self, positions: list[float], t_now: float) -> bool: """Set target joint positions. diff --git a/dimos/control/tasks/teleop_task.py b/dimos/control/tasks/teleop_task.py index ce63dc4006..115b455fe6 100644 --- a/dimos/control/tasks/teleop_task.py +++ b/dimos/control/tasks/teleop_task.py @@ -295,10 +295,6 @@ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: if joints & self._joint_names: logger.warning(f"TeleopIKTask {self._name} preempted by {by_task} on joints {joints}") - # ========================================================================= - # Task-specific methods - # ========================================================================= - def on_buttons(self, msg: Buttons) -> bool: """Press-and-hold engage: hold primary button to track, release to stop.""" is_left = self._config.hand == "left" diff --git a/dimos/control/tasks/trajectory_task.py b/dimos/control/tasks/trajectory_task.py index 4d2eaa188b..16a271018a 100644 --- a/dimos/control/tasks/trajectory_task.py +++ b/dimos/control/tasks/trajectory_task.py @@ -171,10 +171,6 @@ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: if joints & self._joint_names: self._state = TrajectoryState.ABORTED - # ========================================================================= - # Task-specific methods - # ========================================================================= - def execute(self, trajectory: JointTrajectory) -> bool: """Start executing a trajectory. diff --git a/dimos/control/tasks/velocity_task.py b/dimos/control/tasks/velocity_task.py index 163bc09827..5da475114d 100644 --- a/dimos/control/tasks/velocity_task.py +++ b/dimos/control/tasks/velocity_task.py @@ -191,10 +191,6 @@ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: f"JointVelocityTask {self._name} preempted by {by_task} on joints {joints}" ) - # ========================================================================= - # Task-specific methods - # ========================================================================= - def set_velocities(self, velocities: list[float], t_now: float) -> bool: """Set target joint velocities. diff --git a/dimos/control/test_control.py b/dimos/control/test_control.py index 656678d167..a4b7e0a5bc 100644 --- a/dimos/control/test_control.py +++ b/dimos/control/test_control.py @@ -40,10 +40,6 @@ from dimos.hardware.manipulators.spec import ManipulatorAdapter from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint -# ============================================================================= -# Fixtures -# ============================================================================= - @pytest.fixture def mock_adapter(): @@ -112,11 +108,6 @@ def coordinator_state(): return CoordinatorState(joints=joints, t_now=time.perf_counter(), dt=0.01) -# ============================================================================= -# Test JointCommandOutput -# ============================================================================= - - class TestJointCommandOutput: def test_position_output(self): output = JointCommandOutput( @@ -153,11 +144,6 @@ def test_no_values_returns_none(self): assert output.get_values() is None -# ============================================================================= -# Test JointStateSnapshot -# ============================================================================= - - class TestJointStateSnapshot: def test_get_position(self): snapshot = JointStateSnapshot( @@ -171,11 +157,6 @@ def test_get_position(self): assert snapshot.get_position("nonexistent") is None -# ============================================================================= -# Test ConnectedHardware -# ============================================================================= - - class TestConnectedHardware: def test_joint_names_prefixed(self, connected_hardware): names = connected_hardware.joint_names @@ -206,11 +187,6 @@ def test_write_command(self, connected_hardware, mock_adapter): mock_adapter.write_joint_positions.assert_called() -# ============================================================================= -# Test JointTrajectoryTask -# ============================================================================= - - class TestJointTrajectoryTask: def test_initial_state(self, trajectory_task): assert trajectory_task.name == "test_traj" @@ -314,11 +290,6 @@ def test_progress(self, trajectory_task, simple_trajectory, coordinator_state): assert trajectory_task.get_progress(t_start + 1.0) == pytest.approx(1.0, abs=0.01) -# ============================================================================= -# Test Arbitration Logic -# ============================================================================= - - class TestArbitration: def test_single_task_wins(self): outputs = [ @@ -422,11 +393,6 @@ def test_non_overlapping_joints(self): assert winners["j4"][3] == "task2" -# ============================================================================= -# Test TickLoop -# ============================================================================= - - class TestTickLoop: def test_tick_loop_starts_and_stops(self, mock_adapter): component = HardwareComponent( @@ -498,11 +464,6 @@ def test_tick_loop_calls_compute(self, mock_adapter): assert mock_task.compute.call_count > 0 -# ============================================================================= -# Integration Test -# ============================================================================= - - class TestIntegration: def test_full_trajectory_execution(self, mock_adapter): component = HardwareComponent( diff --git a/dimos/control/tick_loop.py b/dimos/control/tick_loop.py index e0020a34da..e45a17030b 100644 --- a/dimos/control/tick_loop.py +++ b/dimos/control/tick_loop.py @@ -172,26 +172,19 @@ def _tick(self) -> None: self._last_tick_time = t_now self._tick_count += 1 - # === PHASE 1: READ ALL HARDWARE === joint_states = self._read_all_hardware() state = CoordinatorState(joints=joint_states, t_now=t_now, dt=dt) - # === PHASE 2: COMPUTE ALL ACTIVE TASKS === commands = self._compute_all_tasks(state) - # === PHASE 3: ARBITRATE (with mode validation) === joint_commands, preemptions = self._arbitrate(commands) - # === PHASE 4: NOTIFY PREEMPTIONS (once per task) === self._notify_preemptions(preemptions) - # === PHASE 5: ROUTE TO HARDWARE === hw_commands = self._route_to_hardware(joint_commands) - # === PHASE 6: WRITE TO HARDWARE === self._write_all_hardware(hw_commands) - # === PHASE 7: PUBLISH AGGREGATED STATE === if self._publish_callback: self._publish_joint_state(joint_states) diff --git a/dimos/core/daemon.py b/dimos/core/daemon.py index f4a19c9403..61060b2a73 100644 --- a/dimos/core/daemon.py +++ b/dimos/core/daemon.py @@ -31,10 +31,6 @@ logger = setup_logger() -# --------------------------------------------------------------------------- -# Health check (delegates to ModuleCoordinator.health_check) -# --------------------------------------------------------------------------- - def health_check(coordinator: ModuleCoordinator) -> bool: """Verify all coordinator workers are alive after build. @@ -45,11 +41,6 @@ def health_check(coordinator: ModuleCoordinator) -> bool: return coordinator.health_check() -# --------------------------------------------------------------------------- -# Daemonize (double-fork) -# --------------------------------------------------------------------------- - - def daemonize(log_dir: Path) -> None: """Double-fork daemonize the current process. @@ -83,11 +74,6 @@ def daemonize(log_dir: Path) -> None: devnull.close() -# --------------------------------------------------------------------------- -# Signal handler for clean shutdown -# --------------------------------------------------------------------------- - - def install_signal_handlers(entry: RunEntry, coordinator: ModuleCoordinator) -> None: """Install SIGTERM/SIGINT handlers that stop the coordinator and clean the registry.""" diff --git a/dimos/core/test_cli_stop_status.py b/dimos/core/test_cli_stop_status.py index c04d8d2499..5c628f6d92 100644 --- a/dimos/core/test_cli_stop_status.py +++ b/dimos/core/test_cli_stop_status.py @@ -72,11 +72,6 @@ def _entry(run_id: str, pid: int, blueprint: str = "test", **kwargs) -> RunEntry return e -# --------------------------------------------------------------------------- -# STATUS -# --------------------------------------------------------------------------- - - class TestStatusCLI: """Tests for `dimos status` command.""" @@ -132,11 +127,6 @@ def test_status_filters_dead_pids(self): assert "No running" in result.output -# --------------------------------------------------------------------------- -# STOP -# --------------------------------------------------------------------------- - - class TestStopCLI: """Tests for `dimos stop` command.""" diff --git a/dimos/core/test_daemon.py b/dimos/core/test_daemon.py index bd7c6b9ad8..f6dae51433 100644 --- a/dimos/core/test_daemon.py +++ b/dimos/core/test_daemon.py @@ -24,9 +24,6 @@ import pytest -# --------------------------------------------------------------------------- -# Registry tests -# --------------------------------------------------------------------------- from dimos.core import run_registry from dimos.core.run_registry import ( RunEntry, @@ -158,10 +155,6 @@ def test_port_conflict_no_false_positive(self, tmp_registry: Path): assert conflict is None -# --------------------------------------------------------------------------- -# Health check tests -# --------------------------------------------------------------------------- - from dimos.core.module_coordinator import ModuleCoordinator @@ -212,10 +205,6 @@ def test_partial_death(self): assert coord.health_check() is False -# --------------------------------------------------------------------------- -# Daemon tests -# --------------------------------------------------------------------------- - from dimos.core.daemon import daemonize, install_signal_handlers @@ -275,11 +264,6 @@ def test_signal_handler_tolerates_stop_error(self, tmp_registry: Path): assert not entry.registry_path.exists() -# --------------------------------------------------------------------------- -# dimos status tests -# --------------------------------------------------------------------------- - - class TestStatusCommand: """Tests for `dimos status` CLI command.""" @@ -327,11 +311,6 @@ def test_status_filters_dead(self, tmp_path, monkeypatch): assert len(entries) == 0 -# --------------------------------------------------------------------------- -# dimos stop tests -# --------------------------------------------------------------------------- - - class TestStopCommand: """Tests for `dimos stop` CLI command.""" diff --git a/dimos/core/test_e2e_daemon.py b/dimos/core/test_e2e_daemon.py index 7043d0384e..d8ac016faa 100644 --- a/dimos/core/test_e2e_daemon.py +++ b/dimos/core/test_e2e_daemon.py @@ -35,10 +35,6 @@ from dimos.core.stream import Out from dimos.robot.cli.dimos import main -# --------------------------------------------------------------------------- -# Lightweight test modules -# --------------------------------------------------------------------------- - class PingModule(Module): data: Out[str] @@ -54,11 +50,6 @@ def start(self): super().start() -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - @pytest.fixture(autouse=True) def _ci_env(monkeypatch): """Set CI=1 to skip sysctl interactive prompt — scoped per test, not module.""" @@ -114,11 +105,6 @@ def registry_entry(): entry.remove() -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - @pytest.mark.slow class TestDaemonE2E: """End-to-end daemon lifecycle with real workers.""" @@ -216,11 +202,6 @@ def test_stale_cleanup(self, coordinator, registry_entry): assert remaining[0].run_id == registry_entry.run_id -# --------------------------------------------------------------------------- -# E2E: CLI status + stop against real running blueprint -# --------------------------------------------------------------------------- - - @pytest.fixture() def live_blueprint(): """Build PingPong and register. Yields (coord, entry). Cleans up on teardown.""" diff --git a/dimos/core/test_mcp_integration.py b/dimos/core/test_mcp_integration.py index 543b9a7fbd..d7527e31f8 100644 --- a/dimos/core/test_mcp_integration.py +++ b/dimos/core/test_mcp_integration.py @@ -55,11 +55,6 @@ MCP_URL = f"http://localhost:{global_config.mcp_port}/mcp" -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - @pytest.fixture(autouse=True) def _ci_env(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("CI", "1") @@ -121,11 +116,6 @@ def _adapter() -> McpAdapter: return McpAdapter() -# --------------------------------------------------------------------------- -# Tests -- read-only against a shared MCP server -# --------------------------------------------------------------------------- - - @pytest.mark.slow class TestMCPLifecycle: """MCP server lifecycle: start -> respond -> stop -> dead.""" @@ -323,11 +313,6 @@ def test_agent_send_cli(self, mcp_shared: ModuleCoordinator) -> None: assert "hello from CLI" in result.output -# --------------------------------------------------------------------------- -# Tests -- lifecycle management (own setup/teardown per test) -# --------------------------------------------------------------------------- - - @pytest.mark.slow class TestDaemonMCPRecovery: """Test MCP recovery after daemon crashes and restarts.""" diff --git a/dimos/core/tests/demo_devex.py b/dimos/core/tests/demo_devex.py index b9ac1393d7..243c870fab 100644 --- a/dimos/core/tests/demo_devex.py +++ b/dimos/core/tests/demo_devex.py @@ -98,9 +98,6 @@ def main() -> None: print(" Simulating: OpenClaw agent using DimOS") print("=" * 60) - # --------------------------------------------------------------- - # Step 1: dimos run stress-test --daemon - # --------------------------------------------------------------- section("Step 1: dimos run stress-test --daemon") result = run_dimos("run", "stress-test", "--daemon", timeout=60) print(f" stdout: {result.stdout.strip()[:200]}") @@ -131,9 +128,6 @@ def main() -> None: print(" Cannot continue without MCP. Exiting.") sys.exit(1) - # --------------------------------------------------------------- - # Step 2: dimos status - # --------------------------------------------------------------- section("Step 2: dimos status") result = run_dimos("status") print(f" output: {result.stdout.strip()[:300]}") @@ -145,9 +139,6 @@ def main() -> None: p(f"Status unclear (exit={result.returncode})", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 3: dimos mcp list-tools - # --------------------------------------------------------------- section("Step 3: dimos mcp list-tools") result = run_dimos("mcp", "list-tools") if result.returncode == 0: @@ -167,9 +158,6 @@ def main() -> None: p(f"list-tools failed (exit={result.returncode}): {result.stdout[:100]}", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 4: dimos mcp call echo --arg message=hello - # --------------------------------------------------------------- section("Step 4: dimos mcp call echo --arg message=hello") result = run_dimos("mcp", "call", "echo", "--arg", "message=hello-from-devex-test") if result.returncode == 0 and "hello-from-devex-test" in result.stdout: @@ -178,9 +166,6 @@ def main() -> None: p(f"echo call failed (exit={result.returncode}): {result.stdout[:100]}", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 5: dimos mcp status - # --------------------------------------------------------------- section("Step 5: dimos mcp status") result = run_dimos("mcp", "status") if result.returncode == 0: @@ -196,9 +181,6 @@ def main() -> None: p(f"mcp status failed (exit={result.returncode})", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 6: dimos mcp modules - # --------------------------------------------------------------- section("Step 6: dimos mcp modules") result = run_dimos("mcp", "modules") if result.returncode == 0: @@ -213,9 +195,6 @@ def main() -> None: p(f"mcp modules failed (exit={result.returncode})", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 7: dimos agent-send "hello" - # --------------------------------------------------------------- section("Step 7: dimos agent-send 'what tools do you have?'") result = run_dimos("agent-send", "what tools do you have?") if result.returncode == 0: @@ -224,9 +203,6 @@ def main() -> None: p(f"agent-send failed (exit={result.returncode}): {result.stdout[:100]}", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 8: Check logs - # --------------------------------------------------------------- section("Step 8: Check per-run logs") log_base = os.path.expanduser("~/.local/state/dimos/logs") if os.path.isdir(log_base): @@ -257,9 +233,6 @@ def main() -> None: p(f"Log base dir not found: {log_base}", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 9: dimos stop - # --------------------------------------------------------------- section("Step 9: dimos stop") result = run_dimos("stop") print(f" output: {result.stdout.strip()[:200]}") @@ -272,9 +245,6 @@ def main() -> None: # Wait for shutdown time.sleep(2) - # --------------------------------------------------------------- - # Step 10: dimos status (verify stopped) - # --------------------------------------------------------------- section("Step 10: dimos status (verify stopped)") result = run_dimos("status") print(f" output: {result.stdout.strip()[:200]}") @@ -288,9 +258,6 @@ def main() -> None: p(f"Unexpected status after stop (exit={result.returncode})", ok=False) failures += 1 - # --------------------------------------------------------------- - # Summary - # --------------------------------------------------------------- print("\n" + "=" * 60) if failures == 0: print(" \u2705 FULL DEVELOPER EXPERIENCE TEST PASSED") diff --git a/dimos/hardware/drive_trains/flowbase/adapter.py b/dimos/hardware/drive_trains/flowbase/adapter.py index 5b5563792d..ec96365c78 100644 --- a/dimos/hardware/drive_trains/flowbase/adapter.py +++ b/dimos/hardware/drive_trains/flowbase/adapter.py @@ -62,10 +62,6 @@ def __init__(self, dof: int = 3, address: str | None = None, **_: object) -> Non # Last commanded velocities (in standard frame, before negation) self._last_velocities = [0.0, 0.0, 0.0] - # ========================================================================= - # Connection - # ========================================================================= - def connect(self) -> bool: """Connect to FlowBase controller via Portal RPC.""" try: @@ -98,18 +94,10 @@ def is_connected(self) -> bool: """Check if connected to FlowBase.""" return self._connected - # ========================================================================= - # Info - # ========================================================================= - def get_dof(self) -> int: """FlowBase is always 3 DOF (vx, vy, wz).""" return 3 - # ========================================================================= - # State Reading - # ========================================================================= - def read_velocities(self) -> list[float]: """Return last commanded velocities (FlowBase doesn't report actual).""" with self._lock: @@ -134,10 +122,6 @@ def read_odometry(self) -> list[float] | None: logger.error(f"Error reading FlowBase odometry: {e}") return None - # ========================================================================= - # Control - # ========================================================================= - def write_velocities(self, velocities: list[float]) -> bool: """Send velocity command to FlowBase. @@ -165,10 +149,6 @@ def write_stop(self) -> bool: return False return self._send_velocity(0.0, 0.0, 0.0) - # ========================================================================= - # Enable/Disable - # ========================================================================= - def write_enable(self, enable: bool) -> bool: """Enable/disable the platform (FlowBase is always enabled when connected).""" self._enabled = enable @@ -178,10 +158,6 @@ def read_enabled(self) -> bool: """Check if platform is enabled.""" return self._enabled - # ========================================================================= - # Internal - # ========================================================================= - def _send_velocity(self, vx: float, vy: float, wz: float) -> bool: """Send raw velocity to FlowBase via Portal RPC.""" try: diff --git a/dimos/hardware/drive_trains/mock/adapter.py b/dimos/hardware/drive_trains/mock/adapter.py index 2091ec59d0..d6131305e6 100644 --- a/dimos/hardware/drive_trains/mock/adapter.py +++ b/dimos/hardware/drive_trains/mock/adapter.py @@ -48,10 +48,6 @@ def __init__(self, dof: int = 3, **_: object) -> None: self._enabled = False self._connected = False - # ========================================================================= - # Connection - # ========================================================================= - def connect(self) -> bool: """Simulate connection.""" self._connected = True @@ -65,18 +61,10 @@ def is_connected(self) -> bool: """Check mock connection status.""" return self._connected - # ========================================================================= - # Info - # ========================================================================= - def get_dof(self) -> int: """Return DOF.""" return self._dof - # ========================================================================= - # State Reading - # ========================================================================= - def read_velocities(self) -> list[float]: """Return mock velocities.""" return self._velocities.copy() @@ -87,10 +75,6 @@ def read_odometry(self) -> list[float] | None: return None return self._odometry.copy() - # ========================================================================= - # Control - # ========================================================================= - def write_velocities(self, velocities: list[float]) -> bool: """Set mock velocities.""" if len(velocities) != self._dof: @@ -103,10 +87,6 @@ def write_stop(self) -> bool: self._velocities = [0.0] * self._dof return True - # ========================================================================= - # Enable/Disable - # ========================================================================= - def write_enable(self, enable: bool) -> bool: """Enable/disable mock platform.""" self._enabled = enable @@ -116,10 +96,6 @@ def read_enabled(self) -> bool: """Check mock enable state.""" return self._enabled - # ========================================================================= - # Test Helpers (not part of Protocol) - # ========================================================================= - def set_odometry(self, odometry: list[float] | None) -> None: """Set odometry directly for testing.""" self._odometry = list(odometry) if odometry is not None else None diff --git a/dimos/hardware/drive_trains/spec.py b/dimos/hardware/drive_trains/spec.py index 0b288edfd4..1380ef1fa9 100644 --- a/dimos/hardware/drive_trains/spec.py +++ b/dimos/hardware/drive_trains/spec.py @@ -35,8 +35,6 @@ class TwistBaseAdapter(Protocol): - Angle: radians """ - # --- Connection --- - def connect(self) -> bool: """Connect to hardware. Returns True on success.""" ... @@ -49,14 +47,10 @@ def is_connected(self) -> bool: """Check if connected.""" ... - # --- Info --- - def get_dof(self) -> int: """Get number of velocity DOFs (e.g., 3 for holonomic, 2 for differential).""" ... - # --- State Reading --- - def read_velocities(self) -> list[float]: """Read current velocities in virtual joint order (m/s or rad/s).""" ... @@ -69,8 +63,6 @@ def read_odometry(self) -> list[float] | None: """ ... - # --- Control --- - def write_velocities(self, velocities: list[float]) -> bool: """Command velocities in virtual joint order. Returns success.""" ... @@ -79,8 +71,6 @@ def write_stop(self) -> bool: """Stop all motion immediately (zero velocities).""" ... - # --- Enable/Disable --- - def write_enable(self, enable: bool) -> bool: """Enable or disable the platform. Returns success.""" ... diff --git a/dimos/hardware/manipulators/mock/adapter.py b/dimos/hardware/manipulators/mock/adapter.py index ff299669f7..53c53c722d 100644 --- a/dimos/hardware/manipulators/mock/adapter.py +++ b/dimos/hardware/manipulators/mock/adapter.py @@ -66,10 +66,6 @@ def __init__(self, dof: int = 6, **_: object) -> None: self._error_code: int = 0 self._error_message: str = "" - # ========================================================================= - # Connection - # ========================================================================= - def connect(self) -> bool: """Simulate connection.""" self._connected = True @@ -83,10 +79,6 @@ def is_connected(self) -> bool: """Check mock connection status.""" return self._connected - # ========================================================================= - # Info - # ========================================================================= - def get_info(self) -> ManipulatorInfo: """Return mock info.""" return ManipulatorInfo( @@ -109,10 +101,6 @@ def get_limits(self) -> JointLimits: velocity_max=[1.0] * self._dof, ) - # ========================================================================= - # Control Mode - # ========================================================================= - def set_control_mode(self, mode: ControlMode) -> bool: """Set mock control mode.""" self._control_mode = mode @@ -122,10 +110,6 @@ def get_control_mode(self) -> ControlMode: """Get mock control mode.""" return self._control_mode - # ========================================================================= - # State Reading - # ========================================================================= - def read_joint_positions(self) -> list[float]: """Return mock joint positions.""" return self._positions.copy() @@ -151,10 +135,6 @@ def read_error(self) -> tuple[int, str]: """Return mock error.""" return self._error_code, self._error_message - # ========================================================================= - # Motion Control - # ========================================================================= - def write_joint_positions( self, positions: list[float], @@ -178,10 +158,6 @@ def write_stop(self) -> bool: self._velocities = [0.0] * self._dof return True - # ========================================================================= - # Servo Control - # ========================================================================= - def write_enable(self, enable: bool) -> bool: """Enable/disable mock servos.""" self._enabled = enable @@ -197,10 +173,6 @@ def write_clear_errors(self) -> bool: self._error_message = "" return True - # ========================================================================= - # Cartesian Control (Optional) - # ========================================================================= - def read_cartesian_position(self) -> dict[str, float] | None: """Return mock cartesian position.""" return self._cartesian_position.copy() @@ -214,10 +186,6 @@ def write_cartesian_position( self._cartesian_position.update(pose) return True - # ========================================================================= - # Gripper (Optional) - # ========================================================================= - def read_gripper_position(self) -> float | None: """Return mock gripper position.""" return self._gripper_position @@ -227,18 +195,10 @@ def write_gripper_position(self, position: float) -> bool: self._gripper_position = position return True - # ========================================================================= - # Force/Torque (Optional) - # ========================================================================= - def read_force_torque(self) -> list[float] | None: """Return mock F/T sensor data (not supported in mock).""" return None - # ========================================================================= - # Test Helpers (not part of Protocol) - # ========================================================================= - def set_error(self, code: int, message: str) -> None: """Inject an error for testing error handling.""" self._error_code = code diff --git a/dimos/hardware/manipulators/piper/adapter.py b/dimos/hardware/manipulators/piper/adapter.py index 68b5769a95..49ed68bcf9 100644 --- a/dimos/hardware/manipulators/piper/adapter.py +++ b/dimos/hardware/manipulators/piper/adapter.py @@ -75,10 +75,6 @@ def __init__( self._enabled: bool = False self._control_mode: ControlMode = ControlMode.POSITION - # ========================================================================= - # Connection - # ========================================================================= - def connect(self) -> bool: """Connect to Piper via CAN bus.""" try: @@ -139,10 +135,6 @@ def is_connected(self) -> bool: except Exception: return False - # ========================================================================= - # Info - # ========================================================================= - def get_info(self) -> ManipulatorInfo: """Get Piper information.""" firmware_version = None @@ -176,10 +168,6 @@ def get_limits(self) -> JointLimits: velocity_max=max_vel, ) - # ========================================================================= - # Control Mode - # ========================================================================= - def set_control_mode(self, mode: ControlMode) -> bool: """Set Piper control mode via MotionCtrl_2.""" if not self._sdk: @@ -207,10 +195,6 @@ def get_control_mode(self) -> ControlMode: """Get current control mode.""" return self._control_mode - # ========================================================================= - # State Reading - # ========================================================================= - def read_joint_positions(self) -> list[float]: """Read joint positions (Piper units -> radians).""" if not self._sdk: @@ -295,10 +279,6 @@ def read_error(self) -> tuple[int, str]: return 0, "" - # ========================================================================= - # Motion Control (Joint Space) - # ========================================================================= - def write_joint_positions( self, positions: list[float], @@ -366,10 +346,6 @@ def write_stop(self) -> bool: # Fallback: disable arm return self.write_enable(False) - # ========================================================================= - # Servo Control - # ========================================================================= - def write_enable(self, enable: bool) -> bool: """Enable or disable servos.""" if not self._sdk: @@ -427,10 +403,6 @@ def write_clear_errors(self) -> bool: time.sleep(0.1) return self.write_enable(True) - # ========================================================================= - # Cartesian Control (Optional) - # ========================================================================= - def read_cartesian_position(self) -> dict[str, float] | None: """Read end-effector pose. @@ -470,10 +442,6 @@ def write_cartesian_position( # Cartesian control not commonly supported in Piper SDK return False - # ========================================================================= - # Gripper (Optional) - # ========================================================================= - def read_gripper_position(self) -> float | None: """Read gripper position (percentage -> meters).""" if not self._sdk: @@ -508,10 +476,6 @@ def write_gripper_position(self, position: float) -> bool: return False - # ========================================================================= - # Force/Torque Sensor (Optional) - # ========================================================================= - def read_force_torque(self) -> list[float] | None: """Read F/T sensor data. diff --git a/dimos/hardware/manipulators/spec.py b/dimos/hardware/manipulators/spec.py index ff4d38c54f..ed63a21e82 100644 --- a/dimos/hardware/manipulators/spec.py +++ b/dimos/hardware/manipulators/spec.py @@ -28,10 +28,6 @@ from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -# ============================================================================ -# SHARED TYPES -# ============================================================================ - class DriverStatus(Enum): """Status returned by driver operations.""" @@ -83,11 +79,6 @@ def default_base_transform() -> Transform: ) -# ============================================================================ -# ADAPTER PROTOCOL -# ============================================================================ - - @runtime_checkable class ManipulatorAdapter(Protocol): """Protocol for hardware-specific IO. @@ -100,8 +91,6 @@ class ManipulatorAdapter(Protocol): - Force: Newtons """ - # --- Connection --- - def connect(self) -> bool: """Connect to hardware. Returns True on success.""" ... @@ -114,8 +103,6 @@ def is_connected(self) -> bool: """Check if connected.""" ... - # --- Info --- - def get_info(self) -> ManipulatorInfo: """Get manipulator info (vendor, model, DOF).""" ... @@ -128,8 +115,6 @@ def get_limits(self) -> JointLimits: """Get joint limits.""" ... - # --- Control Mode --- - def set_control_mode(self, mode: ControlMode) -> bool: """Set control mode (position, velocity, torque, cartesian, etc). @@ -152,8 +137,6 @@ def get_control_mode(self) -> ControlMode: """ ... - # --- State Reading --- - def read_joint_positions(self) -> list[float]: """Read current joint positions (radians).""" ... @@ -174,8 +157,6 @@ def read_error(self) -> tuple[int, str]: """Read error code and message. (0, '') means no error.""" ... - # --- Motion Control (Joint Space) --- - def write_joint_positions( self, positions: list[float], @@ -192,8 +173,6 @@ def write_stop(self) -> bool: """Stop all motion immediately.""" ... - # --- Servo Control --- - def write_enable(self, enable: bool) -> bool: """Enable or disable servos. Returns success.""" ... @@ -206,7 +185,6 @@ def write_clear_errors(self) -> bool: """Clear error state. Returns success.""" ... - # --- Optional: Cartesian Control --- # Return None/False if not supported def read_cartesian_position(self) -> dict[str, float] | None: @@ -234,8 +212,6 @@ def write_cartesian_position( """ ... - # --- Optional: Gripper --- - def read_gripper_position(self) -> float | None: """Read gripper position (meters). None if no gripper.""" ... @@ -244,8 +220,6 @@ def write_gripper_position(self, position: float) -> bool: """Command gripper position. False if no gripper.""" ... - # --- Optional: Force/Torque Sensor --- - def read_force_torque(self) -> list[float] | None: """Read F/T sensor [fx, fy, fz, tx, ty, tz]. None if no sensor.""" ... diff --git a/dimos/hardware/manipulators/xarm/adapter.py b/dimos/hardware/manipulators/xarm/adapter.py index 80cc8edb38..3e24c530d1 100644 --- a/dimos/hardware/manipulators/xarm/adapter.py +++ b/dimos/hardware/manipulators/xarm/adapter.py @@ -64,10 +64,6 @@ def __init__(self, address: str, dof: int = 6, **_: object) -> None: self._control_mode: ControlMode = ControlMode.POSITION self._gripper_enabled: bool = False - # ========================================================================= - # Connection - # ========================================================================= - def connect(self) -> bool: """Connect to XArm via TCP/IP.""" try: @@ -98,10 +94,6 @@ def is_connected(self) -> bool: """Check if connected to XArm.""" return self._arm is not None and self._arm.connected - # ========================================================================= - # Info - # ========================================================================= - def get_info(self) -> ManipulatorInfo: """Get XArm information.""" return ManipulatorInfo( @@ -124,10 +116,6 @@ def get_limits(self) -> JointLimits: velocity_max=[math.pi] * self._dof, # ~180 deg/s ) - # ========================================================================= - # Control Mode - # ========================================================================= - def set_control_mode(self, mode: ControlMode) -> bool: """Set XArm control mode. @@ -161,10 +149,6 @@ def get_control_mode(self) -> ControlMode: """Get current control mode.""" return self._control_mode - # ========================================================================= - # State Reading - # ========================================================================= - def read_joint_positions(self) -> list[float]: """Read joint positions (degrees -> radians).""" if not self._arm: @@ -214,10 +198,6 @@ def read_error(self) -> tuple[int, str]: return 0, "" return code, f"XArm error {code}" - # ========================================================================= - # Motion Control (Joint Space) - # ========================================================================= - def write_joint_positions( self, positions: list[float], @@ -263,10 +243,6 @@ def write_stop(self) -> bool: code: int = self._arm.emergency_stop() return code == 0 - # ========================================================================= - # Servo Control - # ========================================================================= - def write_enable(self, enable: bool) -> bool: """Enable or disable servos.""" if not self._arm: @@ -289,10 +265,6 @@ def write_clear_errors(self) -> bool: code: int = self._arm.clean_error() return code == 0 - # ========================================================================= - # Cartesian Control (Optional) - # ========================================================================= - def read_cartesian_position(self) -> dict[str, float] | None: """Read end-effector pose (mm -> meters, degrees -> radians).""" if not self._arm: @@ -331,10 +303,6 @@ def write_cartesian_position( ) return code == 0 - # ========================================================================= - # Gripper (Optional) - # ========================================================================= - def read_gripper_position(self) -> float | None: """Read gripper position (mm -> meters).""" if not self._arm: @@ -359,10 +327,6 @@ def write_gripper_position(self, position: float) -> bool: code: int = self._arm.set_gripper_position(pos_mm, wait=False) return code == 0 - # ========================================================================= - # Force/Torque Sensor (Optional) - # ========================================================================= - def read_force_torque(self) -> list[float] | None: """Read F/T sensor data if available.""" if not self._arm: diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index 97657b9cae..7a0eefb37a 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -45,10 +45,6 @@ from dimos.robot.foxglove_bridge import foxglove_bridge # TODO: migrate to rerun from dimos.utils.data import get_data -# ============================================================================= -# Pose Helpers -# ============================================================================= - def _make_base_pose( x: float = 0.0, @@ -70,11 +66,6 @@ def _make_base_pose( ) -# ============================================================================= -# URDF Helpers -# ============================================================================= - - def _get_xarm_urdf_path() -> Path: """Get path to xarm URDF.""" return get_data("xarm_description") / "urdf/xarm_device.urdf.xacro" @@ -133,11 +124,6 @@ def _get_piper_package_paths() -> dict[str, Path]: ] -# ============================================================================= -# Robot Configs -# ============================================================================= - - def _make_xarm6_config( name: str = "arm", y_offset: float = 0.0, @@ -283,11 +269,6 @@ def _make_piper_config( ) -# ============================================================================= -# Blueprints -# ============================================================================= - - # Single XArm6 planner (standalone, no coordinator) xarm6_planner_only = manipulation_module( robots=[_make_xarm6_config()], diff --git a/dimos/manipulation/control/coordinator_client.py b/dimos/manipulation/control/coordinator_client.py index 4e277fae97..cbaad28df2 100644 --- a/dimos/manipulation/control/coordinator_client.py +++ b/dimos/manipulation/control/coordinator_client.py @@ -98,10 +98,6 @@ def stop(self) -> None: """Stop the RPC client.""" self._rpc.stop_rpc_client() - # ========================================================================= - # Query methods (RPC calls) - # ========================================================================= - def list_hardware(self) -> list[str]: """List all hardware IDs.""" return self._rpc.list_hardware() or [] @@ -129,10 +125,6 @@ def get_trajectory_status(self, task_name: str) -> dict[str, Any]: return {"state": int(result), "task": task_name} return {} - # ========================================================================= - # Trajectory execution (via task_invoke) - # ========================================================================= - def execute_trajectory(self, task_name: str, trajectory: JointTrajectory) -> bool: """Execute a trajectory on a task via task_invoke.""" result = self._rpc.task_invoke(task_name, "execute", {"trajectory": trajectory}) @@ -143,10 +135,6 @@ def cancel_trajectory(self, task_name: str) -> bool: result = self._rpc.task_invoke(task_name, "cancel", {}) return bool(result) - # ========================================================================= - # Task selection and setup - # ========================================================================= - def select_task(self, task_name: str) -> bool: """ Select a task and setup its trajectory generator. @@ -248,11 +236,6 @@ def set_acceleration_limit(self, acceleration: float, task_name: str | None = No gen.set_limits(gen.max_velocity, acceleration) -# ============================================================================= -# Interactive CLI -# ============================================================================= - - def parse_joint_input(line: str, num_joints: int) -> list[float] | None: """Parse joint positions from user input (degrees by default, 'r' suffix for radians).""" parts = line.strip().split() diff --git a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py index 7dd7e8c119..a12fb44a96 100644 --- a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py +++ b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py @@ -272,10 +272,6 @@ def stop(self) -> None: super().stop() logger.info("CartesianMotionController stopped") - # ========================================================================= - # RPC Methods - High-level control - # ========================================================================= - @rpc def set_target_pose( self, position: list[float], orientation: list[float], frame_id: str = "world" @@ -349,10 +345,6 @@ def is_converged(self) -> bool: and ori_error < self.config.orientation_tolerance ) - # ========================================================================= - # Private Methods - Callbacks - # ========================================================================= - def _on_joint_state(self, msg: JointState) -> None: """Callback when new joint state is received.""" logger.debug(f"Received joint_state: {len(msg.position)} joints") @@ -372,10 +364,6 @@ def _on_target_pose(self, msg: PoseStamped) -> None: self._is_tracking = True logger.debug(f"New target received: {msg}") - # ========================================================================= - # Private Methods - Control Loop - # ========================================================================= - def _control_loop(self) -> None: """ Main control loop running at control_frequency Hz. diff --git a/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py b/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py index ebc6f3f53c..ed62a7345e 100644 --- a/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py +++ b/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py @@ -153,10 +153,6 @@ def stop(self) -> None: super().stop() logger.info("JointTrajectoryController stopped") - # ========================================================================= - # RPC Methods - Action-server-like interface - # ========================================================================= - @rpc def execute_trajectory(self, trajectory: JointTrajectory) -> bool: """ @@ -270,10 +266,6 @@ def get_status(self) -> TrajectoryStatus: error=self._error_message, ) - # ========================================================================= - # Callbacks - # ========================================================================= - def _on_joint_state(self, msg: JointState) -> None: """Callback for joint state feedback.""" self._latest_joint_state = msg @@ -289,10 +281,6 @@ def _on_trajectory(self, msg: JointTrajectory) -> None: ) self.execute_trajectory(msg) - # ========================================================================= - # Execution Loop - # ========================================================================= - def _execution_loop(self) -> None: """ Main execution loop running at control_frequency Hz. diff --git a/dimos/manipulation/manipulation_interface.py b/dimos/manipulation/manipulation_interface.py index 524562520d..c60cbfd9c6 100644 --- a/dimos/manipulation/manipulation_interface.py +++ b/dimos/manipulation/manipulation_interface.py @@ -157,8 +157,6 @@ def update_task_result(self, task_id: str, result: dict[str, Any]) -> Manipulati return task return None - # === Perception stream methods === - def _setup_perception_subscription(self) -> None: """ Set up subscription to perception stream if available. @@ -239,8 +237,6 @@ def cleanup_perception_subscription(self) -> None: self.stream_subscription.dispose() self.stream_subscription = None - # === Utility methods === - def clear(self) -> None: """ Clear all manipulation tasks and agent constraints. diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index cab6c9f173..f064130965 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -278,10 +278,6 @@ def _tf_publish_loop(self) -> None: self._tf_stop_event.wait(period) - # ========================================================================= - # RPC Methods - # ========================================================================= - @rpc def get_state(self) -> str: """Get current manipulation state name.""" @@ -356,10 +352,6 @@ def is_collision_free(self, joints: list[float], robot_name: RobotName | None = return self._world_monitor.is_state_valid(robot_id, joint_state) return False - # ========================================================================= - # Plan/Preview/Execute Workflow RPC Methods - # ========================================================================= - def _begin_planning( self, robot_name: RobotName | None = None ) -> tuple[RobotName, WorldRobotID] | None: @@ -630,10 +622,6 @@ def set_init_joints_to_current(self, robot_name: RobotName | None = None) -> boo ) return True - # ========================================================================= - # Coordinator Integration RPC Methods - # ========================================================================= - def _get_coordinator_client(self) -> RPCClient | None: """Get or create coordinator RPC client (lazy init).""" if not any( @@ -780,10 +768,6 @@ def remove_obstacle(self, obstacle_id: str) -> bool: return False return self._world_monitor.remove_obstacle(obstacle_id) - # ========================================================================= - # Gripper Methods - # ========================================================================= - def _get_gripper_hardware_id(self, robot_name: RobotName | None = None) -> str | None: """Get gripper hardware ID for a robot.""" robot = self._get_robot(robot_name) @@ -856,10 +840,6 @@ def close_gripper(self, robot_name: str | None = None) -> str: return "Gripper closed" return "Error: Failed to close gripper" - # ========================================================================= - # Skill Helpers (internal) - # ========================================================================= - def _wait_for_trajectory_completion( self, robot_name: RobotName | None = None, timeout: float = 60.0, poll_interval: float = 0.2 ) -> bool: @@ -944,10 +924,6 @@ def _preview_execute_wait( return None - # ========================================================================= - # Short-Horizon Skills — Single-step actions - # ========================================================================= - @skill def get_robot_state(self, robot_name: str | None = None) -> str: """Get current robot state: joint positions, end-effector pose, and gripper. @@ -1132,10 +1108,6 @@ def go_init(self, robot_name: str | None = None) -> str: return "Reached init position" - # ========================================================================= - # Lifecycle - # ========================================================================= - @rpc def stop(self) -> None: """Stop the manipulation module.""" diff --git a/dimos/manipulation/pick_and_place_module.py b/dimos/manipulation/pick_and_place_module.py index 2016abeb4f..6d6ad1042e 100644 --- a/dimos/manipulation/pick_and_place_module.py +++ b/dimos/manipulation/pick_and_place_module.py @@ -102,10 +102,6 @@ def __init__(self, **kwargs: Any) -> None: # so pick/place use this stable snapshot instead. self._detection_snapshot: list[DetObject] = [] - # ========================================================================= - # Lifecycle (perception integration) - # ========================================================================= - @rpc def start(self) -> None: """Start the pick-and-place module (adds perception subscriptions).""" @@ -130,10 +126,6 @@ def _on_objects(self, objects: list[DetObject]) -> None: except Exception as e: logger.error(f"Exception in _on_objects: {e}") - # ========================================================================= - # Perception RPC Methods - # ========================================================================= - @rpc def refresh_obstacles(self, min_duration: float = 0.0) -> list[dict[str, Any]]: """Refresh perception obstacles. Returns the list of obstacles added. @@ -182,10 +174,6 @@ def list_added_obstacles(self) -> list[dict[str, Any]]: return [] return self._world_monitor.list_added_obstacles() - # ========================================================================= - # GraspGen - # ========================================================================= - def _get_graspgen(self) -> DockerRunner: """Get or create GraspGen Docker module (lazy init, thread-safe).""" # Fast path: already initialized (no lock needed for read) @@ -250,10 +238,6 @@ def generate_grasps( logger.error(f"Grasp generation failed: {e}") return None - # ========================================================================= - # Pick/Place Helpers - # ========================================================================= - def _compute_pre_grasp_pose(self, grasp_pose: Pose, offset: float = 0.10) -> Pose: """Compute a pre-grasp pose offset along the approach direction (local -Z). @@ -321,10 +305,6 @@ def _generate_grasps_for_pick( logger.info(f"Heuristic grasp for '{object_name}' at ({c.x:.3f}, {c.y:.3f}, {c.z:.3f})") return [grasp_pose] - # ========================================================================= - # Perception Skills - # ========================================================================= - @skill def get_scene_info(self, robot_name: str | None = None) -> str: """Get current robot state, detected objects, and scene information. @@ -410,10 +390,6 @@ def scan_objects(self, min_duration: float = 1.0, robot_name: str | None = None) return "\n".join(lines) - # ========================================================================= - # Long-Horizon Skills — Pick and Place - # ========================================================================= - @skill def pick( self, @@ -602,10 +578,6 @@ def pick_and_place( # Place phase return self.place(place_x, place_y, place_z, robot_name) - # ========================================================================= - # Lifecycle - # ========================================================================= - @rpc def stop(self) -> None: """Stop the pick-and-place module (cleanup GraspGen + delegate to base).""" diff --git a/dimos/manipulation/planning/kinematics/jacobian_ik.py b/dimos/manipulation/planning/kinematics/jacobian_ik.py index 5f80642058..c756045d36 100644 --- a/dimos/manipulation/planning/kinematics/jacobian_ik.py +++ b/dimos/manipulation/planning/kinematics/jacobian_ik.py @@ -395,7 +395,7 @@ def solve_differential_position_only( return JointState(name=joint_names, velocity=q_dot.tolist()) -# ============= Result Helpers ============= +# Result Helpers def _create_success_result( diff --git a/dimos/manipulation/planning/kinematics/pinocchio_ik.py b/dimos/manipulation/planning/kinematics/pinocchio_ik.py index 4224dda556..ff1c2dcc2a 100644 --- a/dimos/manipulation/planning/kinematics/pinocchio_ik.py +++ b/dimos/manipulation/planning/kinematics/pinocchio_ik.py @@ -49,11 +49,6 @@ logger = setup_logger() -# ============================================================================= -# Configuration -# ============================================================================= - - @dataclass class PinocchioIKConfig: """Configuration for the Pinocchio IK solver. @@ -73,11 +68,6 @@ class PinocchioIKConfig: max_velocity: float = 10.0 -# ============================================================================= -# PinocchioIK Solver -# ============================================================================= - - class PinocchioIK: """Pinocchio-based damped least-squares IK solver. @@ -162,10 +152,6 @@ def ee_joint_id(self) -> int: """End-effector joint ID.""" return self._ee_joint_id - # ========================================================================= - # Core IK - # ========================================================================= - def solve( self, target_pose: pinocchio.SE3, @@ -208,10 +194,6 @@ def solve( return q, False, final_err - # ========================================================================= - # Forward Kinematics - # ========================================================================= - def forward_kinematics(self, joint_positions: NDArray[np.floating[Any]]) -> pinocchio.SE3: """Compute end-effector pose from joint positions. @@ -225,11 +207,6 @@ def forward_kinematics(self, joint_positions: NDArray[np.floating[Any]]) -> pino return self._data.oMi[self._ee_joint_id].copy() -# ============================================================================= -# Pose Conversion Helpers -# ============================================================================= - - def pose_to_se3(pose: Pose | PoseStamped) -> pinocchio.SE3: """Convert Pose or PoseStamped to pinocchio SE3""" @@ -239,11 +216,6 @@ def pose_to_se3(pose: Pose | PoseStamped) -> pinocchio.SE3: return pinocchio.SE3(rotation, position) -# ============================================================================= -# Safety Utilities -# ============================================================================= - - def check_joint_delta( q_new: NDArray[np.floating[Any]], q_current: NDArray[np.floating[Any]], diff --git a/dimos/manipulation/planning/monitor/world_monitor.py b/dimos/manipulation/planning/monitor/world_monitor.py index 33017957dc..cca2dda013 100644 --- a/dimos/manipulation/planning/monitor/world_monitor.py +++ b/dimos/manipulation/planning/monitor/world_monitor.py @@ -66,7 +66,7 @@ def __init__( self._viz_stop_event = threading.Event() self._viz_rate_hz: float = 10.0 - # ============= Robot Management ============= + # Robot Management def add_robot(self, config: RobotModelConfig) -> WorldRobotID: """Add a robot. Returns robot_id.""" @@ -93,7 +93,7 @@ def get_joint_limits( with self._lock: return self._world.get_joint_limits(robot_id) - # ============= Obstacle Management ============= + # Obstacle Management def add_obstacle(self, obstacle: Obstacle) -> str: """Add an obstacle. Returns obstacle_id.""" @@ -110,7 +110,7 @@ def clear_obstacles(self) -> None: with self._lock: self._world.clear_obstacles() - # ============= Monitor Control ============= + # Monitor Control def start_state_monitor( self, @@ -181,7 +181,7 @@ def stop_all_monitors(self) -> None: self._world.close() - # ============= Message Handlers ============= + # Message Handlers def on_joint_state(self, msg: JointState, robot_id: WorldRobotID | None = None) -> None: """Handle joint state message. Broadcasts to all monitors if robot_id is None.""" @@ -252,7 +252,7 @@ def list_added_obstacles(self) -> list[dict[str, Any]]: return self._obstacle_monitor.list_added_obstacles() return [] - # ============= State Access ============= + # State Access def get_current_joint_state(self, robot_id: WorldRobotID) -> JointState | None: """Get current joint state. Returns None if not yet received.""" @@ -294,7 +294,7 @@ def is_state_stale(self, robot_id: WorldRobotID, max_age: float = 1.0) -> bool: return self._state_monitors[robot_id].is_state_stale(max_age) return True - # ============= Context Management ============= + # Context Management @contextmanager def scratch_context(self) -> Generator[Any, None, None]: @@ -306,7 +306,7 @@ def get_live_context(self) -> Any: """Get live context. Prefer scratch_context() for planning.""" return self._world.get_live_context() - # ============= Collision Checking ============= + # Collision Checking def is_state_valid(self, robot_id: WorldRobotID, joint_state: JointState) -> bool: """Check if configuration is collision-free.""" @@ -340,7 +340,7 @@ def get_min_distance(self, robot_id: WorldRobotID) -> float: with self._world.scratch_context() as ctx: return self._world.get_min_distance(ctx, robot_id) - # ============= Kinematics ============= + # Kinematics def get_ee_pose( self, robot_id: WorldRobotID, joint_state: JointState | None = None @@ -394,7 +394,7 @@ def get_jacobian(self, robot_id: WorldRobotID, joint_state: JointState) -> NDArr self._world.set_joint_state(ctx, robot_id, joint_state) return self._world.get_jacobian(ctx, robot_id) - # ============= Lifecycle ============= + # Lifecycle def finalize(self) -> None: """Finalize world. Must be called before collision checking.""" @@ -407,7 +407,7 @@ def is_finalized(self) -> bool: """Check if world is finalized.""" return self._world.is_finalized - # ============= Visualization ============= + # Visualization def get_visualization_url(self) -> str | None: """Get visualization URL or None if not enabled.""" @@ -466,7 +466,7 @@ def _visualization_loop(self) -> None: logger.debug(f"Visualization publish failed: {e}") time.sleep(period) - # ============= Direct World Access ============= + # Direct World Access @property def world(self) -> WorldSpec: diff --git a/dimos/manipulation/planning/monitor/world_obstacle_monitor.py b/dimos/manipulation/planning/monitor/world_obstacle_monitor.py index a96d3efaf6..4f69afad68 100644 --- a/dimos/manipulation/planning/monitor/world_obstacle_monitor.py +++ b/dimos/manipulation/planning/monitor/world_obstacle_monitor.py @@ -406,7 +406,7 @@ def remove_obstacle_callback( if callback in self._obstacle_callbacks: self._obstacle_callbacks.remove(callback) - # ============= Object-Based Perception (from ObjectDB) ============= + # Object-Based Perception (from ObjectDB) def on_objects(self, objects: list[object]) -> None: """Cache objects from ObjectDB (preserves stable object_id). diff --git a/dimos/manipulation/planning/planners/rrt_planner.py b/dimos/manipulation/planning/planners/rrt_planner.py index f2be8736d5..71204488c4 100644 --- a/dimos/manipulation/planning/planners/rrt_planner.py +++ b/dimos/manipulation/planning/planners/rrt_planner.py @@ -315,7 +315,7 @@ def _simplify_path( return simplified -# ============= Result Helpers ============= +# Result Helpers def _create_success_result( diff --git a/dimos/manipulation/planning/spec/types.py b/dimos/manipulation/planning/spec/types.py index a38cc0da26..2683db7814 100644 --- a/dimos/manipulation/planning/spec/types.py +++ b/dimos/manipulation/planning/spec/types.py @@ -32,9 +32,6 @@ from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.sensor_msgs import JointState -# ============================================================================= -# Semantic ID Types (documentation only, not enforced at runtime) -# ============================================================================= RobotName: TypeAlias = str """User-facing robot name (e.g., 'left_arm', 'right_arm')""" @@ -45,19 +42,11 @@ JointPath: TypeAlias = "list[JointState]" """List of joint states forming a path (each waypoint has names + positions)""" -# ============================================================================= -# Numeric Array Types -# ============================================================================= Jacobian: TypeAlias = "NDArray[np.float64]" """6 x n Jacobian matrix (rows: [vx, vy, vz, wx, wy, wz])""" -# ============================================================================= -# Data Classes -# ============================================================================= - - @dataclass class Obstacle: """Obstacle specification for collision avoidance. diff --git a/dimos/manipulation/planning/world/drake_world.py b/dimos/manipulation/planning/world/drake_world.py index 2ab996f410..147e1e3ad3 100644 --- a/dimos/manipulation/planning/world/drake_world.py +++ b/dimos/manipulation/planning/world/drake_world.py @@ -124,8 +124,6 @@ def _call(self, fn: Any, *args: Any, **kwargs: Any) -> Any: return fn(*args, **kwargs) return self._executor.submit(fn, *args, **kwargs).result() - # --- Meshcat proxies --- - def SetObject(self, *args: Any, **kwargs: Any) -> Any: return self._call(self._inner.SetObject, *args, **kwargs) @@ -327,7 +325,7 @@ def get_joint_limits( np.full(n_joints, np.pi), ) - # ============= Obstacle Management ============= + # Obstacle Management def add_obstacle(self, obstacle: Obstacle) -> str: """Add an obstacle to the world.""" @@ -536,7 +534,7 @@ def clear_obstacles(self) -> None: for obs_id in obstacle_ids: self.remove_obstacle(obs_id) - # ============= Preview Robot Setup ============= + # Preview Robot Setup def _set_preview_colors(self) -> None: """Set all preview robot visual geometries to yellow/semi-transparent.""" @@ -565,7 +563,7 @@ def _remove_preview_collision_roles(self) -> None: for geom_id in self._plant.GetCollisionGeometriesForBody(body): self._scene_graph.RemoveRole(source_id, geom_id, Role.kProximity) - # ============= Lifecycle ============= + # Lifecycle def finalize(self) -> None: """Finalize world - locks robot topology, enables collision checking.""" @@ -683,7 +681,7 @@ def _exclude_body_pair(self, body1: Any, body2: Any) -> None: ) ) - # ============= Context Management ============= + # Context Management def get_live_context(self) -> Context: """Get the live context (mirrors current robot state). @@ -736,7 +734,7 @@ def sync_from_joint_state(self, robot_id: WorldRobotID, joint_state: JointState) # Calling ForcedPublish from the LCM callback thread blocks message processing. # Visualization can be updated via publish_to_meshcat() from non-callback contexts. - # ============= State Operations (context-based) ============= + # State Operations (context-based) def set_joint_state( self, ctx: Context, robot_id: WorldRobotID, joint_state: JointState @@ -782,7 +780,7 @@ def get_joint_state(self, ctx: Context, robot_id: WorldRobotID) -> JointState: positions = [float(full_positions[idx]) for idx in robot_data.joint_indices] return JointState(name=robot_data.config.joint_names, position=positions) - # ============= Collision Checking (context-based) ============= + # Collision Checking (context-based) def is_collision_free(self, ctx: Context, robot_id: WorldRobotID) -> bool: """Check if current configuration in context is collision-free.""" @@ -812,7 +810,7 @@ def get_min_distance(self, ctx: Context, robot_id: WorldRobotID) -> float: return float(min(pair.distance for pair in signed_distance_pairs)) - # ============= Collision Checking (context-free, for planning) ============= + # Collision Checking (context-free, for planning) def check_config_collision_free(self, robot_id: WorldRobotID, joint_state: JointState) -> bool: """Check if a joint state is collision-free (manages context internally). @@ -859,7 +857,7 @@ def check_edge_collision_free( return True - # ============= Forward Kinematics (context-based) ============= + # Forward Kinematics (context-based) def get_ee_pose(self, ctx: Context, robot_id: WorldRobotID) -> PoseStamped: """Get end-effector pose.""" @@ -944,7 +942,7 @@ def get_jacobian(self, ctx: Context, robot_id: WorldRobotID) -> NDArray[np.float return J_reordered - # ============= Visualization ============= + # Visualization def get_visualization_url(self) -> str | None: """Get visualization URL if enabled.""" @@ -1029,7 +1027,7 @@ def close(self) -> None: if self._meshcat is not None: self._meshcat.close() - # ============= Direct Access (use with caution) ============= + # Direct Access (use with caution) @property def plant(self) -> MultibodyPlant: diff --git a/dimos/manipulation/test_manipulation_unit.py b/dimos/manipulation/test_manipulation_unit.py index 4aa232c74f..cfd6e35fda 100644 --- a/dimos/manipulation/test_manipulation_unit.py +++ b/dimos/manipulation/test_manipulation_unit.py @@ -30,10 +30,6 @@ from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint -# ============================================================================= -# Fixtures -# ============================================================================= - @pytest.fixture def robot_config(): @@ -103,11 +99,6 @@ def _make_module(): return module -# ============================================================================= -# Test State Machine -# ============================================================================= - - class TestStateMachine: """Test state transitions.""" @@ -167,11 +158,6 @@ def test_begin_planning_state_checks(self, robot_config): assert module._begin_planning() is None -# ============================================================================= -# Test Robot Selection -# ============================================================================= - - class TestRobotSelection: """Test robot selection logic.""" @@ -201,11 +187,6 @@ def test_multiple_robots_require_name(self, robot_config): assert result[0] == "left" -# ============================================================================= -# Test Joint Name Translation (for coordinator integration) -# ============================================================================= - - class TestJointNameTranslation: """Test trajectory joint name translation for coordinator.""" @@ -227,11 +208,6 @@ def test_mapping_translates_names(self, robot_config_with_mapping, simple_trajec assert len(result.points) == 2 # Points preserved -# ============================================================================= -# Test Execute Method -# ============================================================================= - - class TestExecute: """Test coordinator execution.""" @@ -288,11 +264,6 @@ def test_execute_rejected(self, robot_config, simple_trajectory): assert module._state == ManipulationState.FAULT -# ============================================================================= -# Test RobotModelConfig Mapping Helpers -# ============================================================================= - - class TestRobotModelConfigMapping: """Test RobotModelConfig joint name mapping helpers.""" diff --git a/dimos/memory/timeseries/base.py b/dimos/memory/timeseries/base.py index 0d88355b5b..2831836020 100644 --- a/dimos/memory/timeseries/base.py +++ b/dimos/memory/timeseries/base.py @@ -92,8 +92,6 @@ def _find_after(self, timestamp: float) -> tuple[float, T] | None: """Find the first (ts, data) strictly after the given timestamp.""" ... - # --- Collection API (built on abstract methods) --- - def __len__(self) -> int: return self._count() diff --git a/dimos/memory/timeseries/legacy.py b/dimos/memory/timeseries/legacy.py index 15a4ff90fa..a98b0baddf 100644 --- a/dimos/memory/timeseries/legacy.py +++ b/dimos/memory/timeseries/legacy.py @@ -232,8 +232,6 @@ def _find_after(self, timestamp: float) -> tuple[float, T] | None: return (ts, data) return None - # === Backward-compatible API (TimedSensorReplay/SensorReplay) === - @property def files(self) -> list[Path]: """Return list of pickle files (backward compatibility with SensorReplay).""" diff --git a/dimos/perception/experimental/temporal_memory/entity_graph_db.py b/dimos/perception/experimental/temporal_memory/entity_graph_db.py index a2f5b41cbf..11c90cda87 100644 --- a/dimos/perception/experimental/temporal_memory/entity_graph_db.py +++ b/dimos/perception/experimental/temporal_memory/entity_graph_db.py @@ -122,7 +122,7 @@ def _init_schema(self) -> None: conn.commit() - # ==================== Entity Operations ==================== + # Entity Operations def upsert_entity( self, @@ -216,7 +216,7 @@ def get_entities_by_time( for row in cursor.fetchall() ] - # ==================== Relation Operations ==================== + # Relation Operations def add_relation( self, @@ -290,7 +290,7 @@ def get_recent_relations(self, limit: int = 50) -> list[dict[str, Any]]: for row in cursor.fetchall() ] - # ==================== Distance Operations ==================== + # Distance Operations def add_distance( self, @@ -424,7 +424,7 @@ def get_nearby_entities( for row in cursor.fetchall() ] - # ==================== Neighborhood Query ==================== + # Neighborhood Query def get_entity_neighborhood( self, @@ -471,7 +471,7 @@ def get_entity_neighborhood( "num_hops": max_hops, } - # ==================== Stats / Summary ==================== + # Stats / Summary def get_stats(self) -> dict[str, Any]: conn = self._get_connection() @@ -491,7 +491,7 @@ def get_summary(self, recent_relations_limit: int = 5) -> dict[str, Any]: "recent_relations": self.get_recent_relations(limit=recent_relations_limit), } - # ==================== Bulk Save ==================== + # Bulk Save def save_window_data( self, @@ -608,7 +608,7 @@ def estimate_and_save_distances( except Exception as e: logger.warning(f"Failed to estimate distances: {e}", exc_info=True) - # ==================== Lifecycle ==================== + # Lifecycle def commit(self) -> None: if hasattr(self._local, "conn"): diff --git a/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py b/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py index 7af13ad9c2..fc2c9c8a79 100644 --- a/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py +++ b/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py @@ -72,10 +72,6 @@ def __init__( self.stride_s = stride_s self.fps = fps - # ------------------------------------------------------------------ - # Ingest - # ------------------------------------------------------------------ - def set_start_time(self, wall_time: float) -> None: with self._lock: if self._video_start_wall_time is None: @@ -103,10 +99,6 @@ def add_frame(self, image: Image, wall_time: float) -> None: self._buffer.append(frame) self._frame_count += 1 - # ------------------------------------------------------------------ - # Window extraction - # ------------------------------------------------------------------ - def try_extract_window(self) -> list[Frame] | None: """Try to extract a window of frames. @@ -131,10 +123,6 @@ def mark_analysis_time(self, t: float) -> None: with self._lock: self._last_analysis_time = t - # ------------------------------------------------------------------ - # Accessors - # ------------------------------------------------------------------ - @property def frame_count(self) -> int: with self._lock: diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py index 7d01522417..8841d3a6b0 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -203,10 +203,6 @@ def __init__(self, **kwargs: Any) -> None: f"window={self.config.window_s}s, stride={self.config.stride_s}s" ) - # ------------------------------------------------------------------ - # VLM access (lazy) - # ------------------------------------------------------------------ - @property def vlm(self) -> VlModel[Any]: if self._vlm_raw is None: @@ -230,10 +226,6 @@ def _analyzer(self) -> WindowAnalyzer: ) return self.__analyzer - # ------------------------------------------------------------------ - # JSONL logging - # ------------------------------------------------------------------ - def _log_jsonl(self, record: dict[str, Any]) -> None: line = json.dumps(record, ensure_ascii=False) + "\n" # Write to per-run JSONL @@ -250,10 +242,6 @@ def _log_jsonl(self, record: dict[str, Any]) -> None: except Exception as e: logger.warning(f"persistent jsonl log failed: {e}") - # ------------------------------------------------------------------ - # Rerun visualization - # ------------------------------------------------------------------ - def _publish_entity_markers(self) -> None: """Publish entity positions as 3D markers for Rerun overlay on the map.""" if not self.config.visualize: @@ -288,10 +276,6 @@ def _publish_entity_markers(self) -> None: except Exception as e: logger.debug(f"entity marker publish error: {e}") - # ------------------------------------------------------------------ - # Lifecycle - # ------------------------------------------------------------------ - @rpc def start(self) -> None: super().start() @@ -374,10 +358,6 @@ def stop(self) -> None: logger.info("TemporalMemory stopped") - # ------------------------------------------------------------------ - # Core loop - # ------------------------------------------------------------------ - def _analyze_window(self) -> None: if self._stopped: return @@ -518,10 +498,6 @@ def _update_rolling_summary(self, w_end: float) -> None: ) logger.info(f"[temporal-memory] SUMMARY: {sr.summary_text[:300]}") - # ------------------------------------------------------------------ - # Query (agent skill) - # ------------------------------------------------------------------ - @skill def query(self, question: str) -> str: """Answer a question about the video stream using temporal memory and graph knowledge. @@ -611,10 +587,6 @@ def query(self, question: str) -> str: ) return qr.answer - # ------------------------------------------------------------------ - # RPC accessors (backward compat) - # ------------------------------------------------------------------ - @rpc def clear_history(self) -> bool: try: diff --git a/dimos/perception/experimental/temporal_memory/temporal_state.py b/dimos/perception/experimental/temporal_memory/temporal_state.py index 64914761b1..dfc440872d 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_state.py +++ b/dimos/perception/experimental/temporal_memory/temporal_state.py @@ -39,10 +39,6 @@ class TemporalState: _lock: threading.Lock = field(default_factory=threading.Lock, repr=False, compare=False) - # ------------------------------------------------------------------ - # Snapshot - # ------------------------------------------------------------------ - def snapshot(self) -> TemporalState: """Return a deep-copy snapshot (safe to read outside the lock).""" with self._lock: @@ -65,10 +61,6 @@ def to_dict(self) -> dict[str, Any]: "last_present": copy.deepcopy(self.last_present), } - # ------------------------------------------------------------------ - # Mutators - # ------------------------------------------------------------------ - def update_from_window( self, parsed: dict[str, Any], diff --git a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py index abaa99dede..5b37b66770 100644 --- a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py +++ b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py @@ -64,11 +64,6 @@ def _make_image(value: int = 128, shape: tuple[int, ...] = (64, 64, 3)) -> Image return Image.from_numpy(data) -# ====================================================================== -# 1. FrameWindowAccumulator tests -# ====================================================================== - - class TestFrameWindowAccumulator: def test_bounded_buffer(self) -> None: acc = FrameWindowAccumulator(max_buffer_frames=5, window_s=1.0, stride_s=1.0, fps=1.0) @@ -125,11 +120,6 @@ def test_clear(self) -> None: assert acc.buffer_size == 0 -# ====================================================================== -# 2. TemporalState tests -# ====================================================================== - - class TestTemporalState: def test_update_and_snapshot(self) -> None: state = TemporalState(next_summary_at_s=10.0) @@ -226,11 +216,6 @@ def test_auto_add_referenced(self) -> None: assert "E2" in ids -# ====================================================================== -# 3. extract_time_window (regex-only) tests -# ====================================================================== - - class TestExtractTimeWindow: def test_keyword_patterns(self) -> None: assert extract_time_window("just now") == 60 @@ -248,11 +233,6 @@ def test_no_time_reference(self) -> None: assert extract_time_window("is there a person?") is None -# ====================================================================== -# 4. EntityGraphDB tests -# ====================================================================== - - class TestEntityGraphDB: @pytest.fixture def db(self, tmp_path: Path) -> EntityGraphDB: @@ -315,11 +295,6 @@ def test_stats(self, db: EntityGraphDB) -> None: assert "semantic_relations" not in stats -# ====================================================================== -# 5. Persistence test (new_memory flag) -# ====================================================================== - - class TestPersistence: def test_new_memory_clears_db(self, tmp_path: Path) -> None: db_dir = tmp_path / "memory" / "temporal" @@ -372,11 +347,6 @@ def test_persistent_memory_survives(self, tmp_path: Path) -> None: tm.stop() -# ====================================================================== -# 6. Per-run JSONL logging test -# ====================================================================== - - class TestJSONLLogging: def test_log_entries(self, tmp_path: Path) -> None: db_dir = tmp_path / "db" @@ -415,11 +385,6 @@ def test_log_entries(self, tmp_path: Path) -> None: tm.stop() -# ====================================================================== -# 7. Rerun visualization test -# ====================================================================== - - class TestEntityMarkers: def test_publish_entity_markers(self, tmp_path: Path) -> None: db_dir = tmp_path / "db" @@ -482,11 +447,6 @@ def test_markers_to_rerun(self) -> None: assert isinstance(archetype, rr.Points3D) -# ====================================================================== -# 8. WindowAnalyzer mock tests -# ====================================================================== - - class TestWindowAnalyzer: def test_analyze_window_calls_vlm(self) -> None: from dimos.perception.experimental.temporal_memory.window_analyzer import WindowAnalyzer @@ -555,11 +515,6 @@ def test_answer_query(self) -> None: assert result.answer == "The answer is 42" -# ====================================================================== -# 9. Integration test with ModuleCoordinator -# ====================================================================== - - class VideoReplayModule(Module): """Module that replays synthetic video data for tests.""" diff --git a/dimos/perception/experimental/temporal_memory/window_analyzer.py b/dimos/perception/experimental/temporal_memory/window_analyzer.py index 70bfec8d74..cd01a3056d 100644 --- a/dimos/perception/experimental/temporal_memory/window_analyzer.py +++ b/dimos/perception/experimental/temporal_memory/window_analyzer.py @@ -79,10 +79,6 @@ def __init__( def vlm(self) -> VlModel[Any]: return self._vlm - # ------------------------------------------------------------------ - # VLM Call #1: Window analysis - # ------------------------------------------------------------------ - def analyze_window( self, frames: list[Frame], @@ -116,16 +112,8 @@ def analyze_window( parsed = tu.parse_window_response(raw, w_start, w_end, len(frames)) return AnalysisResult(parsed=parsed, raw_vlm_response=raw, w_start=w_start, w_end=w_end) - # ------------------------------------------------------------------ - # VLM Call #2: Distance estimation (delegated to EntityGraphDB) - # ------------------------------------------------------------------ - # Distance estimation is handled by EntityGraphDB.estimate_and_save_distances. # It's called from the orchestrator, not here. - # ------------------------------------------------------------------ - # VLM Call #3: Rolling summary - # ------------------------------------------------------------------ - def update_summary( self, latest_frame: Image, @@ -148,10 +136,6 @@ def update_summary( logger.error(f"summary update failed: {e}", exc_info=True) return None - # ------------------------------------------------------------------ - # VLM Call #5: Query answer - # ------------------------------------------------------------------ - def answer_query( self, question: str, diff --git a/dimos/protocol/pubsub/impl/shmpubsub.py b/dimos/protocol/pubsub/impl/shmpubsub.py index db0a91e579..883afcdcc0 100644 --- a/dimos/protocol/pubsub/impl/shmpubsub.py +++ b/dimos/protocol/pubsub/impl/shmpubsub.py @@ -13,9 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# --------------------------------------------------------------------------- -# SharedMemory Pub/Sub over unified IPC channels (CPU/CUDA) -# --------------------------------------------------------------------------- from __future__ import annotations @@ -101,7 +98,7 @@ def __init__(self, channel, capacity: int, cp_mod) -> None: # type: ignore[no-u # Lock for thread-safe publish buffer access self.publish_lock = threading.Lock() - # ----- init / lifecycle ------------------------------------------------- + # init / lifecycle def __init__( self, @@ -146,7 +143,7 @@ def stop(self) -> None: self._topics.clear() logger.debug("SharedMemory PubSub stopped.") - # ----- PubSub API (bytes on the wire) ---------------------------------- + # PubSub API (bytes on the wire) def publish(self, topic: str, message: bytes) -> None: if not isinstance(message, bytes | bytearray | memoryview): @@ -212,7 +209,7 @@ def _unsub() -> None: return _unsub - # ----- Capacity mgmt ---------------------------------------------------- + # Capacity mgmt def reconfigure(self, topic: str, *, capacity: int) -> dict: # type: ignore[type-arg] """Change payload capacity (bytes) for a topic; returns new descriptor.""" @@ -229,7 +226,7 @@ def reconfigure(self, topic: str, *, capacity: int) -> dict: # type: ignore[typ st.publish_buffer = np.zeros(new_shape, dtype=np.uint8) return desc # type: ignore[no-any-return] - # ----- Internals -------------------------------------------------------- + # Internals def _ensure_topic(self, topic: str) -> _TopicState: with self._lock: diff --git a/dimos/protocol/pubsub/shm/ipc_factory.py b/dimos/protocol/pubsub/shm/ipc_factory.py index fbf98d379e..29ed682f8d 100644 --- a/dimos/protocol/pubsub/shm/ipc_factory.py +++ b/dimos/protocol/pubsub/shm/ipc_factory.py @@ -54,11 +54,6 @@ def _open_shm_with_retry(name: str) -> SharedMemory: raise FileNotFoundError(f"SHM not found after {tries} retries: {name}") from last -# --------------------------- -# 1) Abstract interface -# --------------------------- - - class FrameChannel(ABC): """Single-slot 'freshest frame' IPC channel with a tiny control block. - Double-buffered to avoid torn reads. @@ -125,11 +120,6 @@ def _safe_unlink(name: str) -> None: pass -# --------------------------- -# 2) CPU shared-memory backend -# --------------------------- - - class CpuShmChannel(FrameChannel): def __init__( # type: ignore[no-untyped-def] self, @@ -300,11 +290,6 @@ def close(self) -> None: pass -# --------------------------- -# 3) Factories -# --------------------------- - - class CPU_IPC_Factory: """Creates/attaches CPU shared-memory channels.""" @@ -318,11 +303,6 @@ def attach(desc: dict) -> CpuShmChannel: # type: ignore[type-arg] return CpuShmChannel.attach(desc) # type: ignore[arg-type, no-any-return] -# --------------------------- -# 4) Runtime selector -# --------------------------- - - def make_frame_channel( # type: ignore[no-untyped-def] shape, dtype=np.uint8, prefer: str = "auto", device: int = 0 ) -> FrameChannel: diff --git a/dimos/protocol/service/system_configurator/base.py b/dimos/protocol/service/system_configurator/base.py index c221af890f..e5f65bdc18 100644 --- a/dimos/protocol/service/system_configurator/base.py +++ b/dimos/protocol/service/system_configurator/base.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) -# ----------------------------- sudo helpers ----------------------------- +# sudo helpers @cache @@ -66,7 +66,7 @@ def _write_sysctl_int(name: str, value: int) -> None: sudo_run("sysctl", "-w", f"{name}={value}", check=True, text=True, capture_output=False) -# -------------------------- base class for system config checks/requirements -------------------------- +# base class for system config checks/requirements class SystemConfigurator(ABC): @@ -91,7 +91,7 @@ def fix(self) -> None: raise NotImplementedError -# ----------------------------- generic enforcement of system configs ----------------------------- +# generic enforcement of system configs def configure_system(checks: list[SystemConfigurator], check_only: bool = False) -> None: diff --git a/dimos/protocol/service/system_configurator/lcm.py b/dimos/protocol/service/system_configurator/lcm.py index 6599f97407..9e1b3e5c61 100644 --- a/dimos/protocol/service/system_configurator/lcm.py +++ b/dimos/protocol/service/system_configurator/lcm.py @@ -25,7 +25,7 @@ sudo_run, ) -# ------------------------------ specific checks: multicast ------------------------------ +# specific checks: multicast class MulticastConfiguratorLinux(SystemConfigurator): @@ -182,7 +182,7 @@ def fix(self) -> None: sudo_run(*self.add_route_cmd, check=True, text=True, capture_output=True) -# ------------------------------ specific checks: buffers ------------------------------ +# specific checks: buffers IDEAL_RMEM_SIZE = 67_108_864 # 64MB @@ -254,7 +254,7 @@ def fix(self) -> None: _write_sysctl_int(key, target) -# ------------------------------ specific checks: ulimit ------------------------------ +# specific checks: ulimit class MaxFileConfiguratorMacOS(SystemConfigurator): diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py index a647c89c86..78085e2363 100644 --- a/dimos/protocol/service/test_lcmservice.py +++ b/dimos/protocol/service/test_lcmservice.py @@ -34,7 +34,7 @@ MulticastConfiguratorMacOS, ) -# ----------------------------- autoconf tests ----------------------------- +# autoconf tests class TestConfigureSystemForLcm: @@ -87,7 +87,7 @@ def test_logs_error_on_unsupported_system(self) -> None: assert "Windows" in mock_logger.error.call_args[0][0] -# ----------------------------- LCMConfig tests ----------------------------- +# LCMConfig tests class TestLCMConfig: @@ -103,7 +103,7 @@ def test_custom_url(self) -> None: assert config.url == custom_url -# ----------------------------- Topic tests ----------------------------- +# Topic tests class TestTopic: @@ -118,7 +118,7 @@ def test_str_with_lcm_type(self) -> None: assert str(topic) == "my_topic#TestMessage" -# ----------------------------- LCMService tests ----------------------------- +# LCMService tests class TestLCMService: diff --git a/dimos/protocol/service/test_system_configurator.py b/dimos/protocol/service/test_system_configurator.py index 62de2a61ea..1bd44aa5e2 100644 --- a/dimos/protocol/service/test_system_configurator.py +++ b/dimos/protocol/service/test_system_configurator.py @@ -37,7 +37,7 @@ _write_sysctl_int, ) -# ----------------------------- Helper function tests ----------------------------- +# Helper function tests class TestIsRootUser: @@ -122,7 +122,7 @@ def test_calls_sudo_run_with_correct_args(self) -> None: ) -# ----------------------------- configure_system tests ----------------------------- +# configure_system tests class MockConfigurator(SystemConfigurator): @@ -186,7 +186,7 @@ def test_exits_on_no_with_critical_check(self, mocker) -> None: assert exc_info.value.code == 1 -# ----------------------------- MulticastConfiguratorLinux tests ----------------------------- +# MulticastConfiguratorLinux tests class TestMulticastConfiguratorLinux: @@ -259,7 +259,7 @@ def test_fix_runs_needed_commands(self) -> None: assert mock_run.call_count == 2 -# ----------------------------- MulticastConfiguratorMacOS tests ----------------------------- +# MulticastConfiguratorMacOS tests class TestMulticastConfiguratorMacOS: @@ -311,7 +311,7 @@ def test_fix_runs_route_command(self) -> None: assert "224.0.0.0/4" in add_args -# ----------------------------- BufferConfiguratorLinux tests ----------------------------- +# BufferConfiguratorLinux tests class TestBufferConfiguratorLinux: @@ -354,7 +354,7 @@ def test_fix_writes_needed_values(self) -> None: mock_write.assert_called_once_with("net.core.rmem_max", IDEAL_RMEM_SIZE) -# ----------------------------- BufferConfiguratorMacOS tests ----------------------------- +# BufferConfiguratorMacOS tests class TestBufferConfiguratorMacOS: @@ -398,7 +398,7 @@ def test_fix_writes_needed_values(self) -> None: ) -# ----------------------------- MaxFileConfiguratorMacOS tests ----------------------------- +# MaxFileConfiguratorMacOS tests class TestMaxFileConfiguratorMacOS: @@ -489,7 +489,7 @@ def test_fix_raises_on_setrlimit_error(self) -> None: configurator.fix() -# ----------------------------- ClockSyncConfigurator tests ----------------------------- +# ClockSyncConfigurator tests class TestClockSyncConfigurator: diff --git a/dimos/skills/skills.py b/dimos/skills/skills.py index 94f8b3726f..1fbf6266ef 100644 --- a/dimos/skills/skills.py +++ b/dimos/skills/skills.py @@ -30,12 +30,8 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -# region SkillLibrary - class SkillLibrary: - # ==== Flat Skill Library ==== - def __init__(self) -> None: self.registered_skills: list[AbstractSkill] = [] self.class_skills: list[AbstractSkill] = [] @@ -111,8 +107,6 @@ def __contains__(self, skill: AbstractSkill) -> bool: def __getitem__(self, index): # type: ignore[no-untyped-def] return self.registered_skills[index] - # ==== Calling a Function ==== - _instances: dict[str, dict] = {} # type: ignore[type-arg] def create_instance(self, name: str, **kwargs) -> None: # type: ignore[no-untyped-def] @@ -154,8 +148,6 @@ def call(self, name: str, **args): # type: ignore[no-untyped-def] logger.error(error_msg) return error_msg - # ==== Tools ==== - def get_tools(self) -> Any: tools_json = self.get_list_of_skills_as_json(list_of_skills=self.registered_skills) # print(f"{Colors.YELLOW_PRINT_COLOR}Tools JSON: {tools_json}{Colors.RESET_COLOR}") @@ -250,11 +242,6 @@ def terminate_skill(self, name: str): # type: ignore[no-untyped-def] return f"No running skill found with name: {name}" -# endregion SkillLibrary - -# region AbstractSkill - - class AbstractSkill(BaseModel): def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] print("Initializing AbstractSkill Class") @@ -289,7 +276,6 @@ def unregister_as_running(self, name: str, skill_library: SkillLibrary) -> None: """ skill_library.unregister_running_skill(name) - # ==== Tools ==== def get_tools(self) -> Any: tools_json = self.get_list_of_skills_as_json(list_of_skills=self._list_of_skills) # print(f"Tools JSON: {tools_json}") @@ -299,10 +285,6 @@ def get_list_of_skills_as_json(self, list_of_skills: list[AbstractSkill]) -> lis return list(map(pydantic_function_tool, list_of_skills)) # type: ignore[arg-type] -# endregion AbstractSkill - -# region Abstract Robot Skill - if TYPE_CHECKING: from dimos.robot.robot import Robot else: @@ -338,6 +320,3 @@ def __call__(self): # type: ignore[no-untyped-def] print( f"{Colors.BLUE_PRINT_COLOR}Robot Instance provided to Robot Skill: {self.__class__.__name__}{Colors.RESET_COLOR}" ) - - -# endregion Abstract Robot Skill diff --git a/dimos/stream/frame_processor.py b/dimos/stream/frame_processor.py index ab18400c88..c2db47dc23 100644 --- a/dimos/stream/frame_processor.py +++ b/dimos/stream/frame_processor.py @@ -154,8 +154,6 @@ def visualize_flow(self, flow): # type: ignore[no-untyped-def] rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) return rgb - # ============================== - def process_stream_edge_detection(self, frame_stream): # type: ignore[no-untyped-def] return frame_stream.pipe( ops.map(self.edge_detection), diff --git a/dimos/teleop/phone/phone_teleop_module.py b/dimos/teleop/phone/phone_teleop_module.py index f13842811b..cc55f1f180 100644 --- a/dimos/teleop/phone/phone_teleop_module.py +++ b/dimos/teleop/phone/phone_teleop_module.py @@ -69,10 +69,6 @@ class PhoneTeleopModule(Module[PhoneTeleopConfig]): # Output: velocity command to robot twist_output: Out[TwistStamped] - # ------------------------------------------------------------------------- - # Initialization - # ------------------------------------------------------------------------- - def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -98,10 +94,6 @@ def __init__(self, **kwargs: Any) -> None: self._setup_routes() - # ------------------------------------------------------------------------- - # Web Server Routes - # ------------------------------------------------------------------------- - def _setup_routes(self) -> None: """Register teleop routes on the embedded web server.""" @@ -133,10 +125,6 @@ async def websocket_endpoint(ws: WebSocket) -> None: except Exception: logger.exception("WebSocket error") - # ------------------------------------------------------------------------- - # Lifecycle - # ------------------------------------------------------------------------- - @rpc def start(self) -> None: super().start() @@ -149,10 +137,6 @@ def stop(self) -> None: self._stop_server() super().stop() - # ------------------------------------------------------------------------- - # Internal engage / disengage (assumes lock is held) - # ------------------------------------------------------------------------- - def _engage(self) -> bool: """Engage: capture current sensors as initial""" if self._current_sensors is None: @@ -169,10 +153,6 @@ def _disengage(self) -> None: self._initial_sensors = None logger.info("Phone teleop disengaged") - # ------------------------------------------------------------------------- - # WebSocket Message Decoders - # ------------------------------------------------------------------------- - def _on_sensors_bytes(self, data: bytes) -> None: """Decode raw LCM bytes into TwistStamped and update sensor state.""" msg = TwistStamped.lcm_decode(data) @@ -185,10 +165,6 @@ def _on_button_bytes(self, data: bytes) -> None: with self._lock: self._teleop_button = bool(msg.data) - # ------------------------------------------------------------------------- - # Embedded Web Server - # ------------------------------------------------------------------------- - def _start_server(self) -> None: """Start the embedded FastAPI server with HTTPS in a daemon thread.""" if self._web_server_thread is not None and self._web_server_thread.is_alive(): @@ -212,10 +188,6 @@ def _stop_server(self) -> None: self._web_server_thread = None logger.info("Phone teleop web server stopped") - # ------------------------------------------------------------------------- - # Control Loop - # ------------------------------------------------------------------------- - def _start_control_loop(self) -> None: if self._control_loop_thread is not None and self._control_loop_thread.is_alive(): return @@ -254,10 +226,6 @@ def _control_loop(self) -> None: if sleep_time > 0: self._stop_event.wait(sleep_time) - # ------------------------------------------------------------------------- - # Control Loop Internal Methods - # ------------------------------------------------------------------------- - def _handle_engage(self) -> None: """ Override to customize engagement logic. diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index 5672a2bea0..ac86a0325f 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -26,10 +26,6 @@ from dimos.teleop.quest.quest_extensions import arm_teleop_module, visualizing_teleop_module from dimos.teleop.quest.quest_types import Buttons -# ----------------------------------------------------------------------------- -# Quest Teleop Blueprints -# ----------------------------------------------------------------------------- - # Arm teleop with press-and-hold engage arm_teleop = autoconnect( arm_teleop_module(), @@ -53,10 +49,6 @@ ) -# ----------------------------------------------------------------------------- -# Teleop wired to Coordinator (TeleopIK) -# ----------------------------------------------------------------------------- - # Single XArm7 teleop: right controller -> xarm7 # Usage: dimos run arm-teleop-xarm7 diff --git a/dimos/teleop/quest/quest_teleop_module.py b/dimos/teleop/quest/quest_teleop_module.py index 9beaf0da3e..3c8e6e9812 100644 --- a/dimos/teleop/quest/quest_teleop_module.py +++ b/dimos/teleop/quest/quest_teleop_module.py @@ -98,10 +98,6 @@ class QuestTeleopModule(Module[_Config]): right_controller_output: Out[PoseStamped] buttons: Out[Buttons] - # ------------------------------------------------------------------------- - # Initialization - # ------------------------------------------------------------------------- - def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -131,10 +127,6 @@ def __init__(self, **kwargs: Any) -> None: self._setup_routes() - # ------------------------------------------------------------------------- - # Web Server Routes - # ------------------------------------------------------------------------- - def _setup_routes(self) -> None: """Register teleop routes on the embedded web server.""" @@ -166,10 +158,6 @@ async def websocket_endpoint(ws: WebSocket) -> None: except Exception: logger.exception("WebSocket error") - # ------------------------------------------------------------------------- - # Lifecycle - # ------------------------------------------------------------------------- - @rpc def start(self) -> None: super().start() @@ -183,10 +171,6 @@ def stop(self) -> None: self._stop_server() super().stop() - # ------------------------------------------------------------------------- - # Internal engage/disengage (assumes lock is held) - # ------------------------------------------------------------------------- - def _engage(self, hand: Hand | None = None) -> bool: """Engage a hand. Assumes self._lock is held.""" hands = [hand] if hand is not None else list(Hand) @@ -219,10 +203,6 @@ def get_status(self) -> QuestTeleopStatus: buttons=Buttons.from_controllers(left, right), ) - # ------------------------------------------------------------------------- - # WebSocket Message Decoders - # ------------------------------------------------------------------------- - @staticmethod def _resolve_hand(frame_id: str) -> Hand: if frame_id == "left": @@ -253,10 +233,6 @@ def _on_joy_bytes(self, data: bytes) -> None: with self._lock: self._controllers[hand] = controller - # ------------------------------------------------------------------------- - # Embedded Web Server - # ------------------------------------------------------------------------- - def _start_server(self) -> None: """Start the embedded FastAPI server with HTTPS in a daemon thread.""" if self._web_server_thread is not None and self._web_server_thread.is_alive(): @@ -335,10 +311,6 @@ def _control_loop(self) -> None: if sleep_time > 0: self._stop_event.wait(sleep_time) - # ------------------------------------------------------------------------- - # Control Loop Internals - # ------------------------------------------------------------------------- - def _handle_engage(self) -> None: """Check for engage button press and update per-hand engage state. diff --git a/dimos/test_no_sections.py b/dimos/test_no_sections.py new file mode 100644 index 0000000000..9523c0aae2 --- /dev/null +++ b/dimos/test_no_sections.py @@ -0,0 +1,143 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +from dimos.constants import DIMOS_PROJECT_ROOT + +REPO_ROOT = str(DIMOS_PROJECT_ROOT) + +# Matches lines that are purely separator characters (=== or ---) with optional +# whitespace, e.g.: # ============= or # --------------- +SEPARATOR_LINE = re.compile(r"^\s*#\s*[-=]{10,}\s*$") + +# Matches section headers wrapped in separators, e.g.: +# # === My Section === or # ===== My Section ===== +INLINE_SECTION = re.compile(r"^\s*#\s*[-=]{3,}.+[-=]{3,}\s*$") + +# VS Code-style region markers +REGION_MARKER = re.compile(r"^\s*#\s*(region|endregion)\b") + +SCANNED_EXTENSIONS = { + ".py", + ".yml", + ".yaml", +} + +SCANNED_PREFIXES = { + "Dockerfile", +} + +IGNORED_DIRS = { + ".venv", + "venv", + "__pycache__", + "node_modules", + ".git", + "dist", + "build", + ".egg-info", + ".tox", + # third-party vendored code + "gtsam", +} + +# Lines that match section patterns but are actually programmatic / intentional. +# Each entry is (relative_path, line_substring) — if both match, the line is skipped. +WHITELIST = [ + # Sentinel marker used at runtime to detect already-converted Dockerfiles + ("dimos/core/docker_build.py", "DIMOS_SENTINEL"), +] + + +def _should_scan(path: str) -> bool: + basename = os.path.basename(path) + _, ext = os.path.splitext(basename) + if ext in SCANNED_EXTENSIONS: + return True + for prefix in SCANNED_PREFIXES: + if basename.startswith(prefix): + return True + return False + + +def _is_ignored_dir(dirpath: str) -> bool: + parts = dirpath.split(os.sep) + return bool(IGNORED_DIRS.intersection(parts)) + + +def _is_whitelisted(rel_path: str, line: str) -> bool: + for allowed_path, allowed_substr in WHITELIST: + if rel_path == allowed_path and allowed_substr in line: + return True + return False + + +def find_section_markers() -> list[tuple[str, int, str]]: + """Return a list of (file, line_number, line_text) for every section marker.""" + violations: list[tuple[str, int, str]] = [] + + for dirpath, dirnames, filenames in os.walk(REPO_ROOT): + # Prune ignored directories in-place + dirnames[:] = [d for d in dirnames if d not in IGNORED_DIRS] + + if _is_ignored_dir(dirpath): + continue + + rel_dir = os.path.relpath(dirpath, REPO_ROOT) + + for fname in filenames: + full_path = os.path.join(dirpath, fname) + rel_path = os.path.join(rel_dir, fname) + + if not _should_scan(full_path): + continue + + try: + with open(full_path, encoding="utf-8", errors="replace") as f: + for lineno, line in enumerate(f, start=1): + stripped = line.rstrip("\n") + if _is_whitelisted(rel_path, stripped): + continue + if ( + SEPARATOR_LINE.match(stripped) + or INLINE_SECTION.match(stripped) + or REGION_MARKER.match(stripped) + ): + violations.append((rel_path, lineno, stripped)) + except (OSError, UnicodeDecodeError): + continue + + return violations + + +def test_no_section_markers(): + """ + Fail if any file contains section-style comment markers. + + If a file is too complicated to be understood without sections, then the + sections should be files. We don't need "subfiles". + """ + violations = find_section_markers() + if violations: + report_lines = [ + f"Found {len(violations)} section marker(s). " + "If a file is too complicated to be understood without sections, " + 'then the sections should be files. We don\'t need "subfiles".', + "", + ] + for path, lineno, text in violations: + report_lines.append(f" {path}:{lineno}: {text.strip()}") + raise AssertionError("\n".join(report_lines)) diff --git a/dimos/utils/cli/dtop.py b/dimos/utils/cli/dtop.py index fa463c15d6..64529a6bc3 100644 --- a/dimos/utils/cli/dtop.py +++ b/dimos/utils/cli/dtop.py @@ -40,10 +40,6 @@ if TYPE_CHECKING: from collections.abc import Callable -# --------------------------------------------------------------------------- -# Color helpers -# --------------------------------------------------------------------------- - def _heat(ratio: float) -> str: """Map 0..1 ratio to a cyan → yellow → red gradient.""" @@ -96,11 +92,6 @@ def _rel_style(value: float, lo: float, hi: float) -> str: return _heat(min((value - lo) / (hi - lo), 1.0)) -# --------------------------------------------------------------------------- -# Metric formatters (plain strings — color applied separately via _rel_style) -# --------------------------------------------------------------------------- - - def _fmt_pct(v: float) -> str: return f"{v:3.0f}%" @@ -128,11 +119,6 @@ def _fmt_io(v: float) -> str: return f"{v / 1048576:.0f} MB" -# --------------------------------------------------------------------------- -# Metric definitions — add a tuple here to add a new field -# (label, dict_key, format_fn) -# --------------------------------------------------------------------------- - _LINE1: list[tuple[str, str, Callable[[float], str]]] = [ ("CPU", "cpu_percent", _fmt_pct), ("PSS", "pss", _fmt_mem), @@ -162,11 +148,6 @@ def _compute_ranges(data_dicts: list[dict[str, Any]]) -> dict[str, tuple[float, return ranges -# --------------------------------------------------------------------------- -# App -# --------------------------------------------------------------------------- - - class ResourceSpyApp(App[None]): CSS_PATH = "dimos.tcss" @@ -367,10 +348,6 @@ def _make_lines( return [line1, line2] -# --------------------------------------------------------------------------- -# Preview -# --------------------------------------------------------------------------- - _PREVIEW_DATA: dict[str, Any] = { "coordinator": { "cpu_percent": 12.3, diff --git a/dimos/utils/simple_controller.py b/dimos/utils/simple_controller.py index f95350552c..c8a6ade19d 100644 --- a/dimos/utils/simple_controller.py +++ b/dimos/utils/simple_controller.py @@ -20,9 +20,6 @@ def normalize_angle(angle: float): # type: ignore[no-untyped-def] return math.atan2(math.sin(angle), math.cos(angle)) -# ---------------------------- -# PID Controller Class -# ---------------------------- class PIDController: def __init__( # type: ignore[no-untyped-def] self, @@ -120,9 +117,6 @@ def _apply_deadband_compensation(self, error): # type: ignore[no-untyped-def] return error -# ---------------------------- -# Visual Servoing Controller Class -# ---------------------------- class VisualServoingController: def __init__(self, distance_pid_params, angle_pid_params) -> None: # type: ignore[no-untyped-def] """ diff --git a/dimos/utils/test_data.py b/dimos/utils/test_data.py index e55c8b20f3..9970fc5912 100644 --- a/dimos/utils/test_data.py +++ b/dimos/utils/test_data.py @@ -132,11 +132,6 @@ def test_pull_dir() -> None: assert sha256 == expected_hash -# ============================================================================ -# LfsPath Tests -# ============================================================================ - - def test_lfs_path_lazy_creation() -> None: """Test that creating LfsPath doesn't trigger download.""" lfs_path = LfsPath("test_data_file") diff --git a/docker/navigation/.env.hardware b/docker/navigation/.env.hardware index 234e58545c..fc0e34581e 100644 --- a/docker/navigation/.env.hardware +++ b/docker/navigation/.env.hardware @@ -1,16 +1,8 @@ # Hardware Configuration Environment Variables # Copy this file to .env and customize for your hardware setup -# ============================================ -# NVIDIA GPU Support -# ============================================ -# Set the Docker runtime to nvidia for GPU support (it's runc by default) #DOCKER_RUNTIME=nvidia -# ============================================ -# ROS Configuration -# ============================================ -# ROS domain ID for multi-robot setups ROS_DOMAIN_ID=42 # Robot configuration ('mechanum_drive', 'unitree/unitree_g1', 'unitree/unitree_g1', etc) @@ -21,10 +13,6 @@ ROBOT_CONFIG_PATH=mechanum_drive # This can be found in the unitree app under Device settings or via network scan ROBOT_IP= -# ============================================ -# Mid-360 Lidar Configuration -# ============================================ -# Network interface connected to the lidar (e.g., eth0, enp0s3) # Find with: ip addr show LIDAR_INTERFACE=eth0 @@ -43,24 +31,12 @@ LIDAR_GATEWAY=192.168.1.1 # LIDAR_IP=192.168.123.120 # FOR UNITREE G1 EDU LIDAR_IP=192.168.1.116 -# ============================================ -# Motor Controller Configuration -# ============================================ -# Serial device for motor controller # Check with: ls /dev/ttyACM* or ls /dev/ttyUSB* MOTOR_SERIAL_DEVICE=/dev/ttyACM0 -# ============================================ -# Network Communication (for base station) -# ============================================ -# Enable WiFi buffer optimization for data transmission # Set to true if using wireless base station ENABLE_WIFI_BUFFER=false -# ============================================ -# Unitree Robot Configuration -# ============================================ -# Enable Unitree WebRTC control (for Go2, G1) #USE_UNITREE=true # Unitree robot IP address @@ -69,10 +45,6 @@ UNITREE_IP=192.168.12.1 # Unitree connection method (LocalAP or Ethernet) UNITREE_CONN=LocalAP -# ============================================ -# Navigation Options -# ============================================ -# Enable route planner (FAR planner for goal navigation) USE_ROUTE_PLANNER=false # Enable RViz visualization @@ -83,10 +55,6 @@ USE_RVIZ=false # The system will load: MAP_PATH.pcd for SLAM, MAP_PATH_tomogram.pickle for PCT planner MAP_PATH= -# ============================================ -# Device Group IDs -# ============================================ -# Group ID for /dev/input devices (joystick) # Find with: getent group input | cut -d: -f3 INPUT_GID=995 @@ -94,8 +62,4 @@ INPUT_GID=995 # Find with: getent group dialout | cut -d: -f3 DIALOUT_GID=20 -# ============================================ -# Display Configuration -# ============================================ -# X11 display (usually auto-detected) # DISPLAY=:0 diff --git a/docker/navigation/Dockerfile b/docker/navigation/Dockerfile index fa51fd621c..dc2ce54f39 100644 --- a/docker/navigation/Dockerfile +++ b/docker/navigation/Dockerfile @@ -1,39 +1,23 @@ -# ============================================================================= -# DimOS Navigation Docker Image -# ============================================================================= -# # Multi-stage build for ROS 2 navigation with SLAM support. # Includes both arise_slam and FASTLIO2 - select at runtime via LOCALIZATION_METHOD. -# # Supported configurations: # - ROS distributions: humble, jazzy # - SLAM methods: arise_slam (default), fastlio (set LOCALIZATION_METHOD=fastlio) -# # Build: # ./build.sh --humble # Build for ROS 2 Humble # ./build.sh --jazzy # Build for ROS 2 Jazzy -# # Run: # ./start.sh --hardware --route-planner # Uses arise_slam # LOCALIZATION_METHOD=fastlio ./start.sh --hardware --route-planner # Uses FASTLIO2 -# -# ============================================================================= # Build argument for ROS distribution (default: humble) ARG ROS_DISTRO=humble ARG TARGETARCH -# ----------------------------------------------------------------------------- -# Platform-specific base images # - amd64: Use osrf/ros desktop-full (includes Gazebo, full GUI) -# - arm64: Use ros-base (desktop-full not available for ARM) -# ----------------------------------------------------------------------------- FROM osrf/ros:${ROS_DISTRO}-desktop-full AS base-amd64 FROM ros:${ROS_DISTRO}-ros-base AS base-arm64 -# ----------------------------------------------------------------------------- -# STAGE 1: Build Stage - compile all C++ dependencies -# ----------------------------------------------------------------------------- FROM base-${TARGETARCH} AS builder ARG ROS_DISTRO @@ -200,9 +184,6 @@ RUN /bin/bash -c "source /opt/ros/${ROS_DISTRO}/setup.bash && \ echo 'Building with both arise_slam and FASTLIO2' && \ colcon build --cmake-args -DCMAKE_BUILD_TYPE=Release" -# ----------------------------------------------------------------------------- -# STAGE 2: Runtime Stage - minimal image for running -# ----------------------------------------------------------------------------- ARG ROS_DISTRO ARG TARGETARCH FROM base-${TARGETARCH} AS runtime diff --git a/docker/navigation/docker-compose.dev.yml b/docker/navigation/docker-compose.dev.yml index defbdae846..537e00581d 100644 --- a/docker/navigation/docker-compose.dev.yml +++ b/docker/navigation/docker-compose.dev.yml @@ -1,13 +1,6 @@ -# ============================================================================= -# DEVELOPMENT OVERRIDES - Mount source for live editing -# ============================================================================= -# # Usage: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -# # This file adds development-specific volume mounts for editing ROS configs # without rebuilding the image. -# -# ============================================================================= services: dimos_simulation: diff --git a/docs/capabilities/manipulation/adding_a_custom_arm.md b/docs/capabilities/manipulation/adding_a_custom_arm.md index 2b435a50fe..3e931a7f73 100644 --- a/docs/capabilities/manipulation/adding_a_custom_arm.md +++ b/docs/capabilities/manipulation/adding_a_custom_arm.md @@ -116,9 +116,6 @@ class YourArmAdapter: self._sdk: YourArmSDK | None = None self._control_mode: ControlMode = ControlMode.POSITION - # ========================================================================= - # Connection - # ========================================================================= def connect(self) -> bool: """Connect to hardware. Returns True on success.""" @@ -144,9 +141,6 @@ class YourArmAdapter: """Check if connected.""" return self._sdk is not None and self._sdk.is_alive() - # ========================================================================= - # Info - # ========================================================================= def get_info(self) -> ManipulatorInfo: """Get manipulator info (vendor, model, DOF).""" @@ -173,9 +167,6 @@ class YourArmAdapter: velocity_max=[math.pi] * self._dof, # rad/s ) - # ========================================================================= - # Control Mode - # ========================================================================= def set_control_mode(self, mode: ControlMode) -> bool: """Set control mode. @@ -206,9 +197,6 @@ class YourArmAdapter: """Get current control mode.""" return self._control_mode - # ========================================================================= - # State Reading - # ========================================================================= def read_joint_positions(self) -> list[float]: """Read current joint positions in radians. @@ -262,9 +250,6 @@ class YourArmAdapter: return 0, "" return code, f"YourArm error {code}" - # ========================================================================= - # Motion Control (Joint Space) - # ========================================================================= def write_joint_positions( self, @@ -300,9 +285,6 @@ class YourArmAdapter: return False return self._sdk.emergency_stop() - # ========================================================================= - # Servo Control - # ========================================================================= def write_enable(self, enable: bool) -> bool: """Enable or disable servos.""" @@ -322,10 +304,6 @@ class YourArmAdapter: return False return self._sdk.clear_errors() - # ========================================================================= - # Optional: Cartesian Control - # Return None/False if not supported by your arm. - # ========================================================================= def read_cartesian_position(self) -> dict[str, float] | None: """Read end-effector pose. @@ -343,9 +321,6 @@ class YourArmAdapter: """Command end-effector pose. Return False if not supported.""" return False - # ========================================================================= - # Optional: Gripper - # ========================================================================= def read_gripper_position(self) -> float | None: """Read gripper position in meters. Return None if no gripper.""" @@ -355,9 +330,6 @@ class YourArmAdapter: """Command gripper position in meters. Return False if no gripper.""" return False - # ========================================================================= - # Optional: Force/Torque Sensor - # ========================================================================= def read_force_torque(self) -> list[float] | None: """Read F/T sensor data [fx, fy, fz, tx, ty, tz]. None if no sensor.""" @@ -470,9 +442,6 @@ from dimos.control.coordinator import TaskConfig, control_coordinator from dimos.core.transport import LCMTransport from dimos.msgs.sensor_msgs import JointState -# ============================================================================= -# Coordinator Blueprints -# ============================================================================= # YourArm (6-DOF) — real hardware coordinator_yourarm = control_coordinator( @@ -589,9 +558,6 @@ def _make_yourarm_config( Add this to your `dimos/robot/yourarm/blueprints.py` alongside the coordinator blueprint: ```python -# ============================================================================= -# Planner Blueprints (requires URDF) -# ============================================================================= yourarm_planner = manipulation_module( robots=[_make_yourarm_config("arm", joint_prefix="arm_", coordinator_task="traj_arm")], From c89ee5bd8e5e673f34e70b871f3fc148bba4921d Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Sat, 14 Mar 2026 08:34:08 +0200 Subject: [PATCH 09/42] fix(imports): remove dunder init (#1545) * fix(imports): remove dunder init * fix import --- dimos/__init__.py | 0 dimos/agents/agent.py | 2 +- dimos/agents/demo_agent.py | 2 +- dimos/agents/mcp/__init__.py | 0 dimos/agents/mcp/test_mcp_client.py | 2 +- dimos/agents/skills/demo_robot.py | 2 +- .../skills/google_maps_skill_container.py | 2 +- dimos/agents/skills/gps_nav_skill.py | 2 +- dimos/agents/skills/navigation.py | 7 +- dimos/agents/skills/osm.py | 2 +- dimos/agents/skills/person_follow.py | 6 +- .../test_google_maps_skill_container.py | 4 +- dimos/agents/skills/test_gps_nav_skills.py | 2 +- dimos/agents/skills/test_navigation.py | 4 +- dimos/agents/test_agent.py | 2 +- dimos/agents/vlm_agent.py | 2 +- dimos/agents/vlm_stream_tester.py | 2 +- dimos/agents_deprecated/__init__.py | 0 dimos/agents_deprecated/memory/__init__.py | 0 dimos/agents_deprecated/modules/__init__.py | 15 -- dimos/agents_deprecated/modules/base.py | 4 +- .../modules/gateway/__init__.py | 20 --- .../modules/gateway/utils.py | 156 ------------------ .../prompt_builder/__init__.py | 0 dimos/agents_deprecated/tokenizer/__init__.py | 0 dimos/control/__init__.py | 83 ---------- dimos/control/blueprints.py | 5 +- dimos/control/coordinator.py | 21 ++- dimos/control/examples/cartesian_ik_jogger.py | 4 +- dimos/control/task.py | 3 +- dimos/control/tasks/__init__.py | 49 ------ dimos/control/tasks/cartesian_ik_task.py | 3 +- dimos/control/tasks/teleop_task.py | 3 +- dimos/control/tasks/trajectory_task.py | 3 +- dimos/control/test_control.py | 3 +- dimos/control/tick_loop.py | 2 +- dimos/core/__init__.py | 0 dimos/core/blueprints.py | 3 +- dimos/core/docker_runner.py | 2 +- dimos/core/introspection/__init__.py | 20 --- .../core/introspection/blueprint/__init__.py | 24 --- dimos/core/introspection/module/__init__.py | 45 ----- dimos/core/module.py | 12 +- dimos/core/resource_monitor/__init__.py | 17 -- dimos/core/resource_monitor/stats.py | 2 +- dimos/core/rpc_client.py | 3 +- dimos/core/test_blueprints.py | 2 +- dimos/core/test_core.py | 4 +- dimos/core/test_stream.py | 2 +- dimos/core/test_worker.py | 2 +- dimos/core/testing.py | 6 +- dimos/e2e_tests/conftest.py | 3 +- dimos/e2e_tests/lcm_spy.py | 4 +- dimos/e2e_tests/test_control_coordinator.py | 6 +- dimos/e2e_tests/test_simulation_module.py | 4 +- dimos/exceptions/__init__.py | 0 dimos/hardware/__init__.py | 0 dimos/hardware/drive_trains/__init__.py | 15 -- .../drive_trains/flowbase/__init__.py | 15 -- dimos/hardware/drive_trains/mock/__init__.py | 30 ---- dimos/hardware/end_effectors/__init__.py | 17 -- dimos/hardware/manipulators/__init__.py | 51 ------ dimos/hardware/manipulators/mock/__init__.py | 28 ---- dimos/hardware/manipulators/piper/__init__.py | 26 --- dimos/hardware/manipulators/registry.py | 17 +- dimos/hardware/manipulators/spec.py | 4 +- dimos/hardware/manipulators/xarm/__init__.py | 26 --- .../camera/gstreamer/gstreamer_camera.py | 2 +- .../gstreamer/gstreamer_camera_test_script.py | 6 +- dimos/hardware/sensors/camera/module.py | 4 +- .../sensors/camera/realsense/__init__.py | 43 ----- .../sensors/camera/realsense/camera.py | 6 +- dimos/hardware/sensors/camera/spec.py | 5 +- dimos/hardware/sensors/camera/webcam.py | 4 +- dimos/hardware/sensors/camera/zed/camera.py | 6 +- .../camera/zed/{__init__.py => compat.py} | 2 +- dimos/hardware/sensors/camera/zed/test_zed.py | 2 +- dimos/hardware/sensors/fake_zed_module.py | 14 +- dimos/hardware/sensors/lidar/__init__.py | 0 .../sensors/lidar/fastlio2/__init__.py | 0 .../hardware/sensors/lidar/livox/__init__.py | 0 dimos/manipulation/__init__.py | 37 ----- dimos/manipulation/blueprints.py | 11 +- dimos/manipulation/control/__init__.py | 48 ------ .../control/coordinator_client.py | 2 +- .../control/dual_trajectory_setter.py | 4 +- .../control/servo_control/__init__.py | 32 ---- .../cartesian_motion_controller.py | 10 +- dimos/manipulation/control/target_setter.py | 4 +- .../control/trajectory_controller/__init__.py | 31 ---- .../joint_trajectory_controller.py | 7 +- .../control/trajectory_controller/spec.py | 7 +- .../manipulation/control/trajectory_setter.py | 4 +- dimos/manipulation/grasping/__init__.py | 30 ---- dimos/manipulation/grasping/demo_grasping.py | 4 +- .../manipulation/grasping/graspgen_module.py | 6 +- dimos/manipulation/grasping/grasping.py | 4 +- dimos/manipulation/manipulation_module.py | 33 ++-- dimos/manipulation/pick_and_place_module.py | 8 +- dimos/manipulation/planning/__init__.py | 84 ---------- .../planning/examples/__init__.py | 17 -- .../planning/examples/manipulation_client.py | 8 +- dimos/manipulation/planning/factory.py | 6 +- .../planning/kinematics/__init__.py | 51 ------ .../kinematics/drake_optimization_ik.py | 9 +- .../planning/kinematics/jacobian_ik.py | 11 +- .../planning/kinematics/pinocchio_ik.py | 3 +- .../manipulation/planning/monitor/__init__.py | 63 ------- .../planning/monitor/world_monitor.py | 14 +- .../monitor/world_obstacle_monitor.py | 13 +- .../planning/monitor/world_state_monitor.py | 4 +- .../planning/planners/__init__.py | 41 ----- .../planning/planners/rrt_planner.py | 12 +- dimos/manipulation/planning/spec/__init__.py | 51 ------ dimos/manipulation/planning/spec/config.py | 2 +- .../planning/spec/{types.py => models.py} | 4 +- dimos/manipulation/planning/spec/protocols.py | 6 +- .../planning/trajectory_generator/__init__.py | 25 --- .../joint_trajectory_generator.py | 3 +- .../planning/trajectory_generator/spec.py | 2 +- dimos/manipulation/planning/utils/__init__.py | 51 ------ .../planning/utils/kinematics_utils.py | 2 +- .../manipulation/planning/utils/path_utils.py | 5 +- dimos/manipulation/planning/world/__init__.py | 27 --- .../planning/world/drake_world.py | 17 +- .../manipulation/test_manipulation_module.py | 9 +- dimos/manipulation/test_manipulation_unit.py | 9 +- dimos/mapping/__init__.py | 0 dimos/mapping/costmapper.py | 4 +- dimos/mapping/google_maps/google_maps.py | 4 +- .../google_maps/{types.py => models.py} | 0 dimos/mapping/google_maps/test_google_maps.py | 2 +- dimos/mapping/{types.py => models.py} | 0 dimos/mapping/occupancy/path_mask.py | 2 +- dimos/mapping/occupancy/path_resampling.py | 7 +- dimos/mapping/occupancy/test_path_mask.py | 4 +- .../mapping/occupancy/test_path_resampling.py | 2 +- dimos/mapping/occupancy/visualizations.py | 2 +- dimos/mapping/occupancy/visualize_path.py | 2 +- dimos/mapping/osm/__init__.py | 0 dimos/mapping/osm/current_location_map.py | 2 +- dimos/mapping/osm/osm.py | 4 +- dimos/mapping/osm/query.py | 2 +- dimos/mapping/osm/test_osm.py | 2 +- dimos/mapping/pointclouds/demo.py | 4 +- dimos/mapping/pointclouds/occupancy.py | 4 +- dimos/mapping/pointclouds/test_occupancy.py | 2 +- .../pointclouds/test_occupancy_speed.py | 2 +- dimos/mapping/test_voxels.py | 2 +- dimos/mapping/utils/distance.py | 2 +- dimos/mapping/voxels.py | 4 +- dimos/memory/embedding.py | 5 +- dimos/memory/test_embedding.py | 4 +- dimos/memory/timeseries/__init__.py | 41 ----- dimos/models/__init__.py | 0 dimos/models/base.py | 2 +- dimos/models/embedding/__init__.py | 30 ---- dimos/models/embedding/base.py | 2 +- dimos/models/embedding/clip.py | 2 +- dimos/models/embedding/mobileclip.py | 2 +- dimos/models/embedding/test_embedding.py | 2 +- dimos/models/embedding/treid.py | 2 +- dimos/models/segmentation/edge_tam.py | 6 +- dimos/models/vl/__init__.py | 13 -- dimos/models/vl/base.py | 16 +- dimos/models/vl/florence.py | 2 +- dimos/models/vl/moondream.py | 6 +- dimos/models/vl/moondream_hosted.py | 6 +- dimos/models/vl/openai.py | 2 +- dimos/models/vl/qwen.py | 2 +- dimos/models/vl/test_base.py | 4 +- dimos/models/vl/test_captioner.py | 2 +- dimos/models/vl/test_vlm.py | 8 +- dimos/msgs/__init__.py | 4 - dimos/msgs/foxglove_msgs/__init__.py | 3 - dimos/msgs/geometry_msgs/Transform.py | 2 +- dimos/msgs/geometry_msgs/__init__.py | 38 ----- dimos/msgs/geometry_msgs/test_PoseStamped.py | 2 +- dimos/msgs/geometry_msgs/test_Transform.py | 6 +- dimos/msgs/geometry_msgs/test_Twist.py | 4 +- dimos/msgs/geometry_msgs/test_publish.py | 2 +- dimos/msgs/helpers.py | 5 +- dimos/msgs/nav_msgs/OccupancyGrid.py | 3 +- dimos/msgs/nav_msgs/__init__.py | 9 - dimos/msgs/nav_msgs/test_OccupancyGrid.py | 6 +- dimos/msgs/sensor_msgs/Imu.py | 3 +- dimos/msgs/sensor_msgs/PointCloud2.py | 3 +- dimos/msgs/sensor_msgs/__init__.py | 20 --- dimos/msgs/sensor_msgs/test_PointCloud2.py | 4 +- dimos/msgs/sensor_msgs/test_image.py | 2 +- dimos/msgs/std_msgs/__init__.py | 21 --- dimos/msgs/std_msgs/test_header.py | 2 +- dimos/msgs/tf2_msgs/__init__.py | 17 -- dimos/msgs/tf2_msgs/test_TFMessage.py | 6 +- dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py | 6 +- dimos/msgs/trajectory_msgs/__init__.py | 30 ---- dimos/msgs/vision_msgs/__init__.py | 15 -- dimos/navigation/base.py | 2 +- dimos/navigation/bbox_navigation.py | 6 +- dimos/navigation/demo_ros_navigation.py | 4 +- .../frontier_exploration/__init__.py | 3 - .../test_wavefront_frontier_goal_selector.py | 8 +- .../navigation/frontier_exploration/utils.py | 4 +- .../wavefront_frontier_goal_selector.py | 5 +- .../replanning_a_star/controllers.py | 3 +- .../replanning_a_star/global_planner.py | 2 +- .../replanning_a_star/goal_validator.py | 4 +- .../replanning_a_star/local_planner.py | 5 +- .../replanning_a_star/min_cost_astar.py | 7 +- dimos/navigation/replanning_a_star/module.py | 7 +- .../replanning_a_star/path_clearance.py | 2 +- .../replanning_a_star/path_distancer.py | 2 +- .../replanning_a_star/test_goal_validator.py | 2 +- dimos/navigation/rosnav.py | 35 ++-- dimos/navigation/visual/query.py | 2 +- .../visual_servoing/detection_navigation.py | 12 +- .../visual_servoing/visual_servoing_2d.py | 5 +- dimos/perception/__init__.py | 0 dimos/perception/common/__init__.py | 81 --------- dimos/perception/common/utils.py | 8 +- .../demo_object_scene_registration.py | 4 +- dimos/perception/detection/__init__.py | 10 -- dimos/perception/detection/conftest.py | 26 +-- .../detection/detectors/__init__.py | 8 - .../detection/detectors/{types.py => base.py} | 4 +- .../detection/detectors/conftest.py | 2 +- .../detectors/person/test_person_detectors.py | 3 +- .../detection/detectors/person/yolo.py | 6 +- .../detectors/test_bbox_detectors.py | 5 +- dimos/perception/detection/detectors/yolo.py | 6 +- dimos/perception/detection/detectors/yoloe.py | 6 +- dimos/perception/detection/module2D.py | 18 +- dimos/perception/detection/module3D.py | 20 ++- dimos/perception/detection/moduleDB.py | 12 +- dimos/perception/detection/objectDB.py | 4 +- dimos/perception/detection/person_tracker.py | 11 +- dimos/perception/detection/reid/__init__.py | 13 -- .../detection/reid/embedding_id_system.py | 2 +- dimos/perception/detection/reid/module.py | 6 +- .../reid/test_embedding_id_system.py | 2 +- .../perception/detection/reid/test_module.py | 4 +- dimos/perception/detection/type/__init__.py | 36 ---- .../detection/type/detection2d/__init__.py | 0 .../detection/type/detection2d/base.py | 4 +- .../detection/type/detection2d/bbox.py | 6 +- .../type/detection2d/imageDetections2D.py | 4 +- .../detection/type/detection2d/person.py | 2 +- .../detection/type/detection2d/point.py | 6 +- .../detection/type/detection2d/seg.py | 2 +- .../detection2d/test_imageDetections2D.py | 2 +- .../detection/type/detection2d/test_person.py | 2 +- .../detection/type/detection3d/__init__.py | 37 ----- .../detection/type/detection3d/base.py | 2 +- .../detection/type/detection3d/bbox.py | 10 +- .../detection/type/detection3d/object.py | 13 +- .../detection/type/detection3d/pointcloud.py | 6 +- .../type/detection3d/pointcloud_filters.py | 4 +- .../detection/type/imageDetections.py | 6 +- .../detection/type/test_object3d.py | 2 +- dimos/perception/experimental/__init__.py | 15 -- .../temporal_memory/clip_filter.py | 2 +- .../temporal_memory/entity_graph_db.py | 10 +- .../frame_window_accumulator.py | 2 +- .../temporal_memory/temporal_memory.py | 12 +- .../temporal_utils/__init__.py | 46 ------ .../test_temporal_memory_module.py | 10 +- .../temporal_memory/window_analyzer.py | 20 ++- dimos/perception/object_scene_registration.py | 14 +- dimos/perception/object_tracker.py | 19 ++- dimos/perception/object_tracker_2d.py | 6 +- dimos/perception/object_tracker_3d.py | 14 +- dimos/perception/perceive_loop_skill.py | 3 +- dimos/perception/spatial_perception.py | 8 +- dimos/perception/test_spatial_memory.py | 2 +- .../perception/test_spatial_memory_module.py | 6 +- dimos/protocol/__init__.py | 0 .../encode/{__init__.py => encoder.py} | 14 ++ dimos/protocol/pubsub/__init__.py | 9 - dimos/protocol/pubsub/encoders.py | 4 +- dimos/protocol/pubsub/impl/__init__.py | 6 - dimos/protocol/pubsub/impl/lcmpubsub.py | 4 +- dimos/protocol/pubsub/impl/memory.py | 2 +- dimos/protocol/pubsub/impl/rospubsub.py | 2 +- .../pubsub/impl/rospubsub_conversion.py | 2 +- dimos/protocol/pubsub/impl/test_lcmpubsub.py | 4 +- dimos/protocol/pubsub/impl/test_rospubsub.py | 2 +- dimos/protocol/pubsub/test_pattern_sub.py | 4 +- dimos/protocol/pubsub/test_spec.py | 2 +- dimos/protocol/rpc/__init__.py | 18 -- dimos/protocol/rpc/test_lcmrpc.py | 2 +- dimos/protocol/service/__init__.py | 9 - dimos/protocol/service/lcmservice.py | 3 +- .../{__init__.py => lcm_config.py} | 39 +---- dimos/protocol/service/test_lcmservice.py | 16 +- .../service/test_system_configurator.py | 20 +-- dimos/protocol/tf/__init__.py | 17 -- dimos/protocol/tf/test_tf.py | 7 +- dimos/protocol/tf/tf.py | 5 +- dimos/protocol/tf/tflcmcpp.py | 2 +- dimos/robot/__init__.py | 0 dimos/robot/drone/__init__.py | 26 --- dimos/robot/drone/blueprints/__init__.py | 26 --- .../drone/blueprints/agentic/__init__.py | 5 - .../robot/drone/blueprints/basic/__init__.py | 5 - dimos/robot/drone/camera_module.py | 6 +- dimos/robot/drone/connection_module.py | 10 +- dimos/robot/drone/dji_video_stream.py | 4 +- dimos/robot/drone/drone_tracking_module.py | 5 +- dimos/robot/drone/mavlink_connection.py | 7 +- dimos/robot/drone/test_drone.py | 36 ++-- dimos/robot/manipulators/__init__.py | 0 dimos/robot/manipulators/piper/__init__.py | 0 dimos/robot/manipulators/piper/blueprints.py | 8 +- dimos/robot/manipulators/xarm/__init__.py | 0 dimos/robot/manipulators/xarm/blueprints.py | 4 +- dimos/robot/unitree/__init__.py | 0 dimos/robot/unitree/b1/__init__.py | 8 - dimos/robot/unitree/b1/connection.py | 6 +- dimos/robot/unitree/b1/joystick_module.py | 6 +- dimos/robot/unitree/b1/test_connection.py | 3 +- dimos/robot/unitree/b1/unitree_b1.py | 5 +- dimos/robot/unitree/connection.py | 8 +- dimos/robot/unitree/g1/blueprints/__init__.py | 37 ----- .../unitree/g1/blueprints/agentic/__init__.py | 16 -- .../unitree/g1/blueprints/basic/__init__.py | 16 -- .../g1/blueprints/perceptive/__init__.py | 16 -- .../perceptive/unitree_g1_detection.py | 9 +- .../blueprints/perceptive/unitree_g1_shm.py | 2 +- .../g1/blueprints/primitive/__init__.py | 16 -- .../primitive/uintree_g1_primitive_no_nav.py | 20 ++- dimos/robot/unitree/g1/connection.py | 6 +- dimos/robot/unitree/g1/sim.py | 16 +- dimos/robot/unitree/g1/skill_container.py | 3 +- .../robot/unitree/go2/blueprints/__init__.py | 37 ----- .../go2/blueprints/agentic/__init__.py | 16 -- .../agentic/unitree_go2_temporal_memory.py | 5 +- .../unitree/go2/blueprints/basic/__init__.py | 16 -- .../go2/blueprints/basic/unitree_go2_basic.py | 4 +- .../go2/blueprints/basic/unitree_go2_fleet.py | 2 +- .../unitree/go2/blueprints/smart/__init__.py | 16 -- .../go2/blueprints/smart/_with_jpeg.py | 2 +- .../go2/blueprints/smart/unitree_go2.py | 4 +- .../blueprints/smart/unitree_go2_detection.py | 5 +- .../go2/blueprints/smart/unitree_go2_ros.py | 5 +- dimos/robot/unitree/go2/connection.py | 21 ++- dimos/robot/unitree/go2/fleet_connection.py | 2 +- dimos/robot/unitree/keyboard_teleop.py | 3 +- dimos/robot/unitree/modular/detect.py | 17 +- dimos/robot/unitree/mujoco_connection.py | 8 +- dimos/robot/unitree/rosnav.py | 4 +- dimos/robot/unitree/testing/__init__.py | 0 dimos/robot/unitree/testing/mock.py | 2 +- dimos/robot/unitree/testing/test_actors.py | 2 +- dimos/robot/unitree/testing/test_tooling.py | 2 +- dimos/robot/unitree/type/__init__.py | 0 dimos/robot/unitree/type/lidar.py | 2 +- dimos/robot/unitree/type/map.py | 4 +- dimos/robot/unitree/type/odometry.py | 4 +- dimos/robot/unitree/type/test_lidar.py | 4 +- dimos/robot/unitree/type/test_odometry.py | 2 +- .../robot/unitree/unitree_skill_container.py | 4 +- dimos/robot/unitree_webrtc/type/__init__.py | 33 ---- dimos/rxpy_backpressure/__init__.py | 3 - dimos/simulation/__init__.py | 15 -- dimos/simulation/base/__init__.py | 0 dimos/simulation/engines/__init__.py | 25 --- dimos/simulation/engines/base.py | 2 +- dimos/simulation/engines/mujoco_engine.py | 2 +- .../engines/registry.py} | 19 ++- dimos/simulation/genesis/__init__.py | 4 - dimos/simulation/isaac/__init__.py | 4 - dimos/simulation/manipulators/__init__.py | 54 ------ .../manipulators/sim_manip_interface.py | 2 +- dimos/simulation/manipulators/sim_module.py | 6 +- .../manipulators/test_sim_module.py | 2 +- dimos/simulation/mujoco/mujoco_process.py | 2 +- dimos/simulation/mujoco/person_on_track.py | 2 +- dimos/simulation/mujoco/shared_memory.py | 2 +- dimos/simulation/sim_blueprints.py | 10 +- dimos/skills/__init__.py | 0 dimos/skills/rest/__init__.py | 0 dimos/skills/unitree/__init__.py | 0 dimos/spec/__init__.py | 14 -- dimos/spec/control.py | 2 +- dimos/spec/mapping.py | 4 +- dimos/spec/nav.py | 5 +- dimos/spec/perception.py | 5 +- dimos/stream/__init__.py | 0 dimos/stream/audio/__init__.py | 0 dimos/stream/video_providers/__init__.py | 0 dimos/teleop/__init__.py | 15 -- dimos/teleop/keyboard/__init__.py | 0 .../teleop/keyboard/keyboard_teleop_module.py | 2 +- dimos/teleop/phone/__init__.py | 33 ---- dimos/teleop/phone/phone_extensions.py | 4 +- dimos/teleop/phone/phone_teleop_module.py | 4 +- dimos/teleop/quest/__init__.py | 54 ------ dimos/teleop/quest/blueprints.py | 2 +- dimos/teleop/quest/quest_extensions.py | 3 +- dimos/teleop/quest/quest_teleop_module.py | 4 +- dimos/teleop/quest/quest_types.py | 4 +- dimos/teleop/utils/__init__.py | 15 -- dimos/teleop/utils/teleop_transforms.py | 2 +- dimos/teleop/utils/teleop_visualization.py | 2 +- .../__init__.py => test_no_init_files.py} | 26 ++- dimos/types/ros_polyfill.py | 2 +- dimos/types/test_timestamped.py | 6 +- dimos/utils/cli/__init__.py | 0 dimos/utils/cli/agentspy/demo_agentspy.py | 2 +- dimos/utils/decorators/__init__.py | 15 -- dimos/utils/decorators/test_decorators.py | 3 +- dimos/utils/demo_image_encoding.py | 2 +- dimos/utils/docs/test_doclinks.py | 5 +- dimos/utils/reactive.py | 2 +- dimos/utils/test_transform_utils.py | 5 +- dimos/utils/testing/__init__.py | 9 - dimos/utils/testing/test_moment.py | 9 +- dimos/utils/testing/test_replay.py | 2 +- dimos/utils/transform_utils.py | 5 +- dimos/visualization/rerun/bridge.py | 3 +- dimos/web/__init__.py | 0 dimos/web/dimos_interface/__init__.py | 12 -- dimos/web/dimos_interface/api/__init__.py | 0 dimos/web/websocket_vis/costmap_viz.py | 2 +- dimos/web/websocket_vis/path_history.py | 2 +- .../web/websocket_vis/websocket_vis_module.py | 10 +- docs/capabilities/navigation/native/index.md | 2 +- docs/usage/transports/index.md | 2 +- docs/usage/visualization.md | 2 +- examples/simplerobot/simplerobot.py | 6 +- pyproject.toml | 3 +- 431 files changed, 972 insertions(+), 3124 deletions(-) delete mode 100644 dimos/__init__.py delete mode 100644 dimos/agents/mcp/__init__.py delete mode 100644 dimos/agents_deprecated/__init__.py delete mode 100644 dimos/agents_deprecated/memory/__init__.py delete mode 100644 dimos/agents_deprecated/modules/__init__.py delete mode 100644 dimos/agents_deprecated/modules/gateway/__init__.py delete mode 100644 dimos/agents_deprecated/modules/gateway/utils.py delete mode 100644 dimos/agents_deprecated/prompt_builder/__init__.py delete mode 100644 dimos/agents_deprecated/tokenizer/__init__.py delete mode 100644 dimos/control/__init__.py delete mode 100644 dimos/control/tasks/__init__.py delete mode 100644 dimos/core/__init__.py delete mode 100644 dimos/core/introspection/__init__.py delete mode 100644 dimos/core/introspection/blueprint/__init__.py delete mode 100644 dimos/core/introspection/module/__init__.py delete mode 100644 dimos/core/resource_monitor/__init__.py delete mode 100644 dimos/exceptions/__init__.py delete mode 100644 dimos/hardware/__init__.py delete mode 100644 dimos/hardware/drive_trains/__init__.py delete mode 100644 dimos/hardware/drive_trains/flowbase/__init__.py delete mode 100644 dimos/hardware/drive_trains/mock/__init__.py delete mode 100644 dimos/hardware/end_effectors/__init__.py delete mode 100644 dimos/hardware/manipulators/__init__.py delete mode 100644 dimos/hardware/manipulators/mock/__init__.py delete mode 100644 dimos/hardware/manipulators/piper/__init__.py delete mode 100644 dimos/hardware/manipulators/xarm/__init__.py delete mode 100644 dimos/hardware/sensors/camera/realsense/__init__.py rename dimos/hardware/sensors/camera/zed/{__init__.py => compat.py} (97%) delete mode 100644 dimos/hardware/sensors/lidar/__init__.py delete mode 100644 dimos/hardware/sensors/lidar/fastlio2/__init__.py delete mode 100644 dimos/hardware/sensors/lidar/livox/__init__.py delete mode 100644 dimos/manipulation/__init__.py delete mode 100644 dimos/manipulation/control/__init__.py delete mode 100644 dimos/manipulation/control/servo_control/__init__.py delete mode 100644 dimos/manipulation/control/trajectory_controller/__init__.py delete mode 100644 dimos/manipulation/grasping/__init__.py delete mode 100644 dimos/manipulation/planning/__init__.py delete mode 100644 dimos/manipulation/planning/examples/__init__.py delete mode 100644 dimos/manipulation/planning/kinematics/__init__.py delete mode 100644 dimos/manipulation/planning/monitor/__init__.py delete mode 100644 dimos/manipulation/planning/planners/__init__.py delete mode 100644 dimos/manipulation/planning/spec/__init__.py rename dimos/manipulation/planning/spec/{types.py => models.py} (97%) delete mode 100644 dimos/manipulation/planning/trajectory_generator/__init__.py delete mode 100644 dimos/manipulation/planning/utils/__init__.py delete mode 100644 dimos/manipulation/planning/world/__init__.py delete mode 100644 dimos/mapping/__init__.py rename dimos/mapping/google_maps/{types.py => models.py} (100%) rename dimos/mapping/{types.py => models.py} (100%) delete mode 100644 dimos/mapping/osm/__init__.py delete mode 100644 dimos/memory/timeseries/__init__.py delete mode 100644 dimos/models/__init__.py delete mode 100644 dimos/models/embedding/__init__.py delete mode 100644 dimos/models/vl/__init__.py delete mode 100644 dimos/msgs/__init__.py delete mode 100644 dimos/msgs/foxglove_msgs/__init__.py delete mode 100644 dimos/msgs/geometry_msgs/__init__.py delete mode 100644 dimos/msgs/nav_msgs/__init__.py delete mode 100644 dimos/msgs/sensor_msgs/__init__.py delete mode 100644 dimos/msgs/std_msgs/__init__.py delete mode 100644 dimos/msgs/tf2_msgs/__init__.py delete mode 100644 dimos/msgs/trajectory_msgs/__init__.py delete mode 100644 dimos/msgs/vision_msgs/__init__.py delete mode 100644 dimos/navigation/frontier_exploration/__init__.py delete mode 100644 dimos/perception/__init__.py delete mode 100644 dimos/perception/common/__init__.py delete mode 100644 dimos/perception/detection/__init__.py delete mode 100644 dimos/perception/detection/detectors/__init__.py rename dimos/perception/detection/detectors/{types.py => base.py} (84%) delete mode 100644 dimos/perception/detection/reid/__init__.py delete mode 100644 dimos/perception/detection/type/__init__.py delete mode 100644 dimos/perception/detection/type/detection2d/__init__.py delete mode 100644 dimos/perception/detection/type/detection3d/__init__.py delete mode 100644 dimos/perception/experimental/__init__.py delete mode 100644 dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py delete mode 100644 dimos/protocol/__init__.py rename dimos/protocol/encode/{__init__.py => encoder.py} (82%) delete mode 100644 dimos/protocol/pubsub/__init__.py delete mode 100644 dimos/protocol/pubsub/impl/__init__.py delete mode 100644 dimos/protocol/rpc/__init__.py delete mode 100644 dimos/protocol/service/__init__.py rename dimos/protocol/service/system_configurator/{__init__.py => lcm_config.py} (54%) delete mode 100644 dimos/protocol/tf/__init__.py delete mode 100644 dimos/robot/__init__.py delete mode 100644 dimos/robot/drone/__init__.py delete mode 100644 dimos/robot/drone/blueprints/__init__.py delete mode 100644 dimos/robot/drone/blueprints/agentic/__init__.py delete mode 100644 dimos/robot/drone/blueprints/basic/__init__.py delete mode 100644 dimos/robot/manipulators/__init__.py delete mode 100644 dimos/robot/manipulators/piper/__init__.py delete mode 100644 dimos/robot/manipulators/xarm/__init__.py delete mode 100644 dimos/robot/unitree/__init__.py delete mode 100644 dimos/robot/unitree/b1/__init__.py delete mode 100644 dimos/robot/unitree/g1/blueprints/__init__.py delete mode 100644 dimos/robot/unitree/g1/blueprints/agentic/__init__.py delete mode 100644 dimos/robot/unitree/g1/blueprints/basic/__init__.py delete mode 100644 dimos/robot/unitree/g1/blueprints/perceptive/__init__.py delete mode 100644 dimos/robot/unitree/g1/blueprints/primitive/__init__.py delete mode 100644 dimos/robot/unitree/go2/blueprints/__init__.py delete mode 100644 dimos/robot/unitree/go2/blueprints/agentic/__init__.py delete mode 100644 dimos/robot/unitree/go2/blueprints/basic/__init__.py delete mode 100644 dimos/robot/unitree/go2/blueprints/smart/__init__.py delete mode 100644 dimos/robot/unitree/testing/__init__.py delete mode 100644 dimos/robot/unitree/type/__init__.py delete mode 100644 dimos/robot/unitree_webrtc/type/__init__.py delete mode 100644 dimos/rxpy_backpressure/__init__.py delete mode 100644 dimos/simulation/__init__.py delete mode 100644 dimos/simulation/base/__init__.py delete mode 100644 dimos/simulation/engines/__init__.py rename dimos/{msgs/visualization_msgs/__init__.py => simulation/engines/registry.py} (56%) delete mode 100644 dimos/simulation/genesis/__init__.py delete mode 100644 dimos/simulation/isaac/__init__.py delete mode 100644 dimos/simulation/manipulators/__init__.py delete mode 100644 dimos/skills/__init__.py delete mode 100644 dimos/skills/rest/__init__.py delete mode 100644 dimos/skills/unitree/__init__.py delete mode 100644 dimos/spec/__init__.py delete mode 100644 dimos/stream/__init__.py delete mode 100644 dimos/stream/audio/__init__.py delete mode 100644 dimos/stream/video_providers/__init__.py delete mode 100644 dimos/teleop/__init__.py delete mode 100644 dimos/teleop/keyboard/__init__.py delete mode 100644 dimos/teleop/phone/__init__.py delete mode 100644 dimos/teleop/quest/__init__.py delete mode 100644 dimos/teleop/utils/__init__.py rename dimos/{perception/experimental/temporal_memory/__init__.py => test_no_init_files.py} (50%) delete mode 100644 dimos/utils/cli/__init__.py delete mode 100644 dimos/utils/decorators/__init__.py delete mode 100644 dimos/utils/testing/__init__.py delete mode 100644 dimos/web/__init__.py delete mode 100644 dimos/web/dimos_interface/__init__.py delete mode 100644 dimos/web/dimos_interface/api/__init__.py diff --git a/dimos/__init__.py b/dimos/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index ab576fb109..6e24cee870 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -30,7 +30,7 @@ from dimos.core.module import Module, ModuleConfig, SkillInfo from dimos.core.rpc_client import RpcCall, RPCClient from dimos.core.stream import In, Out -from dimos.protocol.rpc import RPCSpec +from dimos.protocol.rpc.spec import RPCSpec from dimos.spec.utils import Spec if TYPE_CHECKING: diff --git a/dimos/agents/demo_agent.py b/dimos/agents/demo_agent.py index bd69fc6cae..b839b0809c 100644 --- a/dimos/agents/demo_agent.py +++ b/dimos/agents/demo_agent.py @@ -14,9 +14,9 @@ from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect -from dimos.hardware.sensors.camera import zed from dimos.hardware.sensors.camera.module import camera_module from dimos.hardware.sensors.camera.webcam import Webcam +from dimos.hardware.sensors.camera.zed import compat as zed demo_agent = autoconnect(Agent.blueprint()) diff --git a/dimos/agents/mcp/__init__.py b/dimos/agents/mcp/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/agents/mcp/test_mcp_client.py b/dimos/agents/mcp/test_mcp_client.py index 56b98c3cd2..c903e5f11c 100644 --- a/dimos/agents/mcp/test_mcp_client.py +++ b/dimos/agents/mcp/test_mcp_client.py @@ -19,7 +19,7 @@ from dimos.agents.annotation import skill from dimos.core.module import Module -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data diff --git a/dimos/agents/skills/demo_robot.py b/dimos/agents/skills/demo_robot.py index aa4e81e2cc..789e26d7e1 100644 --- a/dimos/agents/skills/demo_robot.py +++ b/dimos/agents/skills/demo_robot.py @@ -17,7 +17,7 @@ from dimos.core.module import Module from dimos.core.stream import Out -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon class DemoRobot(Module): diff --git a/dimos/agents/skills/google_maps_skill_container.py b/dimos/agents/skills/google_maps_skill_container.py index c03932924f..e218601696 100644 --- a/dimos/agents/skills/google_maps_skill_container.py +++ b/dimos/agents/skills/google_maps_skill_container.py @@ -20,7 +20,7 @@ from dimos.core.module import Module from dimos.core.stream import In from dimos.mapping.google_maps.google_maps import GoogleMaps -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/agents/skills/gps_nav_skill.py b/dimos/agents/skills/gps_nav_skill.py index 63cf4a3dd3..1464665131 100644 --- a/dimos/agents/skills/gps_nav_skill.py +++ b/dimos/agents/skills/gps_nav_skill.py @@ -19,7 +19,7 @@ from dimos.core.module import Module from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon from dimos.mapping.utils.distance import distance_in_meters from dimos.utils.logging_config import setup_logger diff --git a/dimos/agents/skills/navigation.py b/dimos/agents/skills/navigation.py index 8442846f32..47ae21c799 100644 --- a/dimos/agents/skills/navigation.py +++ b/dimos/agents/skills/navigation.py @@ -22,9 +22,10 @@ from dimos.core.module import Module from dimos.core.stream import In from dimos.models.qwen.bbox import BBox -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.msgs.geometry_msgs.Vector3 import make_vector3 -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3, make_vector3 +from dimos.msgs.sensor_msgs.Image import Image from dimos.navigation.base import NavigationState from dimos.navigation.visual.query import get_object_bbox_from_image from dimos.types.robot_location import RobotLocation diff --git a/dimos/agents/skills/osm.py b/dimos/agents/skills/osm.py index 613bc0806e..d0281fb808 100644 --- a/dimos/agents/skills/osm.py +++ b/dimos/agents/skills/osm.py @@ -16,8 +16,8 @@ from dimos.agents.annotation import skill from dimos.core.module import Module from dimos.core.stream import In +from dimos.mapping.models import LatLon from dimos.mapping.osm.current_location_map import CurrentLocationMap -from dimos.mapping.types import LatLon from dimos.mapping.utils.distance import distance_in_meters from dimos.models.vl.qwen import QwenVlModel from dimos.utils.logging_config import setup_logger diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index 7a6c6ecfe9..f1cafed6cd 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -29,8 +29,10 @@ from dimos.models.segmentation.edge_tam import EdgeTAMProcessor from dimos.models.vl.base import VlModel from dimos.models.vl.create import create -from dimos.msgs.geometry_msgs import Twist -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.navigation.visual.query import get_object_bbox_from_image from dimos.navigation.visual_servoing.detection_navigation import DetectionNavigation from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D diff --git a/dimos/agents/skills/test_google_maps_skill_container.py b/dimos/agents/skills/test_google_maps_skill_container.py index 1519f9d1df..376d6d306e 100644 --- a/dimos/agents/skills/test_google_maps_skill_container.py +++ b/dimos/agents/skills/test_google_maps_skill_container.py @@ -21,8 +21,8 @@ from dimos.agents.skills.google_maps_skill_container import GoogleMapsSkillContainer from dimos.core.module import Module from dimos.core.stream import Out -from dimos.mapping.google_maps.types import Coordinates, LocationContext, Position -from dimos.mapping.types import LatLon +from dimos.mapping.google_maps.models import Coordinates, LocationContext, Position +from dimos.mapping.models import LatLon class FakeGPS(Module): diff --git a/dimos/agents/skills/test_gps_nav_skills.py b/dimos/agents/skills/test_gps_nav_skills.py index 4060b1814e..c1e380ccd1 100644 --- a/dimos/agents/skills/test_gps_nav_skills.py +++ b/dimos/agents/skills/test_gps_nav_skills.py @@ -18,7 +18,7 @@ from dimos.agents.skills.gps_nav_skill import GpsNavSkillContainer from dimos.core.module import Module from dimos.core.stream import Out -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon class FakeGPS(Module): diff --git a/dimos/agents/skills/test_navigation.py b/dimos/agents/skills/test_navigation.py index e31fae93b5..e4a60db081 100644 --- a/dimos/agents/skills/test_navigation.py +++ b/dimos/agents/skills/test_navigation.py @@ -18,8 +18,8 @@ from dimos.agents.skills.navigation import NavigationSkillContainer from dimos.core.module import Module from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image class FakeCamera(Module): diff --git a/dimos/agents/test_agent.py b/dimos/agents/test_agent.py index bb6caa6337..e925e52a4a 100644 --- a/dimos/agents/test_agent.py +++ b/dimos/agents/test_agent.py @@ -19,7 +19,7 @@ from dimos.agents.annotation import skill from dimos.core.module import Module -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data diff --git a/dimos/agents/vlm_agent.py b/dimos/agents/vlm_agent.py index c39f79830a..81bad79ae5 100644 --- a/dimos/agents/vlm_agent.py +++ b/dimos/agents/vlm_agent.py @@ -21,7 +21,7 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: diff --git a/dimos/agents/vlm_stream_tester.py b/dimos/agents/vlm_stream_tester.py index 4126c6b3a0..5f2165dc8d 100644 --- a/dimos/agents/vlm_stream_tester.py +++ b/dimos/agents/vlm_stream_tester.py @@ -20,7 +20,7 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/agents_deprecated/__init__.py b/dimos/agents_deprecated/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/agents_deprecated/memory/__init__.py b/dimos/agents_deprecated/memory/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/agents_deprecated/modules/__init__.py b/dimos/agents_deprecated/modules/__init__.py deleted file mode 100644 index 99163d55d0..0000000000 --- a/dimos/agents_deprecated/modules/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Agent modules for DimOS.""" diff --git a/dimos/agents_deprecated/modules/base.py b/dimos/agents_deprecated/modules/base.py index 891edbe4bd..0927e184fc 100644 --- a/dimos/agents_deprecated/modules/base.py +++ b/dimos/agents_deprecated/modules/base.py @@ -29,9 +29,9 @@ from dimos.utils.logging_config import setup_logger try: - from .gateway import UnifiedGatewayClient + from dimos.agents_deprecated.modules.gateway.client import UnifiedGatewayClient except ImportError: - from dimos.agents_deprecated.modules.gateway import UnifiedGatewayClient + from dimos.agents_deprecated.modules.gateway.client import UnifiedGatewayClient logger = setup_logger() diff --git a/dimos/agents_deprecated/modules/gateway/__init__.py b/dimos/agents_deprecated/modules/gateway/__init__.py deleted file mode 100644 index 58ed40cd95..0000000000 --- a/dimos/agents_deprecated/modules/gateway/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Gateway module for unified LLM access.""" - -from .client import UnifiedGatewayClient -from .utils import convert_tools_to_standard_format, parse_streaming_response - -__all__ = ["UnifiedGatewayClient", "convert_tools_to_standard_format", "parse_streaming_response"] diff --git a/dimos/agents_deprecated/modules/gateway/utils.py b/dimos/agents_deprecated/modules/gateway/utils.py deleted file mode 100644 index 526d3b9724..0000000000 --- a/dimos/agents_deprecated/modules/gateway/utils.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for gateway operations.""" - -import logging -from typing import Any - -logger = logging.getLogger(__name__) - - -def convert_tools_to_standard_format(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Convert DimOS tool format to standard format accepted by gateways. - - DimOS tools come from pydantic_function_tool and have this format: - { - "type": "function", - "function": { - "name": "tool_name", - "description": "tool description", - "parameters": { - "type": "object", - "properties": {...}, - "required": [...] - } - } - } - - We keep this format as it's already standard JSON Schema format. - """ - if not tools: - return [] - - # Tools are already in the correct format from pydantic_function_tool - return tools - - -def parse_streaming_response(chunk: dict[str, Any]) -> dict[str, Any]: - """Parse a streaming response chunk into a standard format. - - Args: - chunk: Raw chunk from the gateway - - Returns: - Parsed chunk with standard fields: - - type: "content" | "tool_call" | "error" | "done" - - content: The actual content (text for content type, tool info for tool_call) - - metadata: Additional information - """ - # Handle TensorZero streaming format - if "choices" in chunk: - # OpenAI-style format from TensorZero - choice = chunk["choices"][0] if chunk["choices"] else {} - delta = choice.get("delta", {}) - - if "content" in delta: - return { - "type": "content", - "content": delta["content"], - "metadata": {"index": choice.get("index", 0)}, - } - elif "tool_calls" in delta: - tool_calls = delta["tool_calls"] - if tool_calls: - tool_call = tool_calls[0] - return { - "type": "tool_call", - "content": { - "id": tool_call.get("id"), - "name": tool_call.get("function", {}).get("name"), - "arguments": tool_call.get("function", {}).get("arguments", ""), - }, - "metadata": {"index": tool_call.get("index", 0)}, - } - elif choice.get("finish_reason"): - return { - "type": "done", - "content": None, - "metadata": {"finish_reason": choice["finish_reason"]}, - } - - # Handle direct content chunks - if isinstance(chunk, str): - return {"type": "content", "content": chunk, "metadata": {}} - - # Handle error responses - if "error" in chunk: - return {"type": "error", "content": chunk["error"], "metadata": chunk} - - # Default fallback - return {"type": "unknown", "content": chunk, "metadata": {}} - - -def create_tool_response(tool_id: str, result: Any, is_error: bool = False) -> dict[str, Any]: - """Create a properly formatted tool response. - - Args: - tool_id: The ID of the tool call - result: The result from executing the tool - is_error: Whether this is an error response - - Returns: - Formatted tool response message - """ - content = str(result) if not isinstance(result, str) else result - - return { - "role": "tool", - "tool_call_id": tool_id, - "content": content, - "name": None, # Will be filled by the calling code - } - - -def extract_image_from_message(message: dict[str, Any]) -> dict[str, Any] | None: - """Extract image data from a message if present. - - Args: - message: Message dict that may contain image data - - Returns: - Dict with image data and metadata, or None if no image - """ - content = message.get("content", []) - - # Handle list content (multimodal) - if isinstance(content, list): - for item in content: - if isinstance(item, dict): - # OpenAI format - if item.get("type") == "image_url": - return { - "format": "openai", - "data": item["image_url"]["url"], - "detail": item["image_url"].get("detail", "auto"), - } - # Anthropic format - elif item.get("type") == "image": - return { - "format": "anthropic", - "data": item["source"]["data"], - "media_type": item["source"].get("media_type", "image/jpeg"), - } - - return None diff --git a/dimos/agents_deprecated/prompt_builder/__init__.py b/dimos/agents_deprecated/prompt_builder/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/agents_deprecated/tokenizer/__init__.py b/dimos/agents_deprecated/tokenizer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/control/__init__.py b/dimos/control/__init__.py deleted file mode 100644 index 639f0ba38a..0000000000 --- a/dimos/control/__init__.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ControlCoordinator - Centralized control for multi-arm coordination. - -This module provides a centralized control coordinator that replaces -per-driver/per-controller loops with a single deterministic tick-based system. - -Features: -- Single tick loop (read -> compute -> arbitrate -> route -> write) -- Per-joint arbitration (highest priority wins) -- Mode conflict detection -- Partial command support (hold last value) -- Aggregated preemption notifications - -Example: - >>> from dimos.control import ControlCoordinator - >>> from dimos.control.tasks import JointTrajectoryTask, JointTrajectoryTaskConfig - >>> from dimos.hardware.manipulators.xarm import XArmAdapter - >>> - >>> # Create coordinator - >>> coord = ControlCoordinator(tick_rate=100.0) - >>> - >>> # Add hardware - >>> adapter = XArmAdapter(ip="192.168.1.185", dof=7) - >>> adapter.connect() - >>> coord.add_hardware("left_arm", adapter) - >>> - >>> # Add task - >>> joints = [f"left_arm_joint{i+1}" for i in range(7)] - >>> task = JointTrajectoryTask( - ... "traj_left", - ... JointTrajectoryTaskConfig(joint_names=joints, priority=10), - ... ) - >>> coord.add_task(task) - >>> - >>> # Start - >>> coord.start() -""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "components": [ - "HardwareComponent", - "HardwareId", - "HardwareType", - "JointName", - "JointState", - "make_gripper_joints", - "make_joints", - ], - "coordinator": [ - "ControlCoordinator", - "ControlCoordinatorConfig", - "TaskConfig", - "control_coordinator", - ], - "hardware_interface": ["ConnectedHardware"], - "task": [ - "ControlMode", - "ControlTask", - "CoordinatorState", - "JointCommandOutput", - "JointStateSnapshot", - "ResourceClaim", - ], - "tick_loop": ["TickLoop"], - }, -) diff --git a/dimos/control/blueprints.py b/dimos/control/blueprints.py index 7c6036b20c..fff2083322 100644 --- a/dimos/control/blueprints.py +++ b/dimos/control/blueprints.py @@ -39,8 +39,9 @@ ) from dimos.control.coordinator import TaskConfig, control_coordinator from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped, Twist -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.quest.quest_types import Buttons from dimos.utils.data import LfsPath diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index 16f4e53f46..0757f27705 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -49,13 +49,9 @@ TwistBaseAdapter, ) from dimos.hardware.manipulators.spec import ManipulatorAdapter -from dimos.msgs.geometry_msgs import ( - PoseStamped, - Twist, -) -from dimos.msgs.sensor_msgs import ( - JointState, -) +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.quest.quest_types import ( Buttons, ) @@ -258,7 +254,10 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: task_type = cfg.type.lower() if task_type == "trajectory": - from dimos.control.tasks import JointTrajectoryTask, JointTrajectoryTaskConfig + from dimos.control.tasks.trajectory_task import ( + JointTrajectoryTask, + JointTrajectoryTaskConfig, + ) return JointTrajectoryTask( cfg.name, @@ -269,7 +268,7 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: ) elif task_type == "servo": - from dimos.control.tasks import JointServoTask, JointServoTaskConfig + from dimos.control.tasks.servo_task import JointServoTask, JointServoTaskConfig return JointServoTask( cfg.name, @@ -280,7 +279,7 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: ) elif task_type == "velocity": - from dimos.control.tasks import JointVelocityTask, JointVelocityTaskConfig + from dimos.control.tasks.velocity_task import JointVelocityTask, JointVelocityTaskConfig return JointVelocityTask( cfg.name, @@ -291,7 +290,7 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: ) elif task_type == "cartesian_ik": - from dimos.control.tasks import CartesianIKTask, CartesianIKTaskConfig + from dimos.control.tasks.cartesian_ik_task import CartesianIKTask, CartesianIKTaskConfig if cfg.model_path is None: raise ValueError(f"CartesianIKTask '{cfg.name}' requires model_path in TaskConfig") diff --git a/dimos/control/examples/cartesian_ik_jogger.py b/dimos/control/examples/cartesian_ik_jogger.py index d2a2f4d119..bf3b36a972 100644 --- a/dimos/control/examples/cartesian_ik_jogger.py +++ b/dimos/control/examples/cartesian_ik_jogger.py @@ -116,7 +116,7 @@ def to_pose_stamped(self, task_name: str) -> Any: Args: task_name: Task name to use as frame_id for routing """ - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -168,7 +168,7 @@ def run_jogger_ui(model_path: str | None = None, ee_joint_id: int = 6) -> None: ee_joint_id: End-effector joint ID in the model """ from dimos.core.transport import LCMTransport - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped # Use Piper model if not specified if model_path is None: diff --git a/dimos/control/task.py b/dimos/control/task.py index c9ef03fbf0..afad70bb05 100644 --- a/dimos/control/task.py +++ b/dimos/control/task.py @@ -34,7 +34,8 @@ from dimos.hardware.manipulators.spec import ControlMode if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import Pose, PoseStamped + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_types import Buttons diff --git a/dimos/control/tasks/__init__.py b/dimos/control/tasks/__init__.py deleted file mode 100644 index 5b869b01f9..0000000000 --- a/dimos/control/tasks/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Task implementations for the ControlCoordinator.""" - -from dimos.control.tasks.cartesian_ik_task import ( - CartesianIKTask, - CartesianIKTaskConfig, -) -from dimos.control.tasks.servo_task import ( - JointServoTask, - JointServoTaskConfig, -) -from dimos.control.tasks.teleop_task import ( - TeleopIKTask, - TeleopIKTaskConfig, -) -from dimos.control.tasks.trajectory_task import ( - JointTrajectoryTask, - JointTrajectoryTaskConfig, -) -from dimos.control.tasks.velocity_task import ( - JointVelocityTask, - JointVelocityTaskConfig, -) - -__all__ = [ - "CartesianIKTask", - "CartesianIKTaskConfig", - "JointServoTask", - "JointServoTaskConfig", - "JointTrajectoryTask", - "JointTrajectoryTaskConfig", - "JointVelocityTask", - "JointVelocityTaskConfig", - "TeleopIKTask", - "TeleopIKTaskConfig", -] diff --git a/dimos/control/tasks/cartesian_ik_task.py b/dimos/control/tasks/cartesian_ik_task.py index 67d4e4ed52..2525db69e6 100644 --- a/dimos/control/tasks/cartesian_ik_task.py +++ b/dimos/control/tasks/cartesian_ik_task.py @@ -50,7 +50,8 @@ from numpy.typing import NDArray import pinocchio # type: ignore[import-untyped] - from dimos.msgs.geometry_msgs import Pose, PoseStamped + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped logger = setup_logger() diff --git a/dimos/control/tasks/teleop_task.py b/dimos/control/tasks/teleop_task.py index 115b455fe6..3f20502759 100644 --- a/dimos/control/tasks/teleop_task.py +++ b/dimos/control/tasks/teleop_task.py @@ -51,7 +51,8 @@ from numpy.typing import NDArray - from dimos.msgs.geometry_msgs import Pose, PoseStamped + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_types import Buttons logger = setup_logger() diff --git a/dimos/control/tasks/trajectory_task.py b/dimos/control/tasks/trajectory_task.py index 16a271018a..fd0a9fda6e 100644 --- a/dimos/control/tasks/trajectory_task.py +++ b/dimos/control/tasks/trajectory_task.py @@ -32,7 +32,8 @@ JointCommandOutput, ResourceClaim, ) -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/control/test_control.py b/dimos/control/test_control.py index a4b7e0a5bc..3de7865ae3 100644 --- a/dimos/control/test_control.py +++ b/dimos/control/test_control.py @@ -38,7 +38,8 @@ ) from dimos.control.tick_loop import TickLoop from dimos.hardware.manipulators.spec import ManipulatorAdapter -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryPoint import TrajectoryPoint @pytest.fixture diff --git a/dimos/control/tick_loop.py b/dimos/control/tick_loop.py index e45a17030b..dc1ed32dbb 100644 --- a/dimos/control/tick_loop.py +++ b/dimos/control/tick_loop.py @@ -38,7 +38,7 @@ JointStateSnapshot, ResourceClaim, ) -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index abfeb29b2f..cac8507881 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -209,7 +209,8 @@ def _is_name_unique(self, name: str) -> bool: return sum(1 for n, _ in self._all_name_types if n == name) == 1 def _run_configurators(self) -> None: - from dimos.protocol.service.system_configurator import configure_system, lcm_configurators + from dimos.protocol.service.system_configurator.base import configure_system + from dimos.protocol.service.system_configurator.lcm_config import lcm_configurators configurators = [*lcm_configurators(), *self.configurator_checks] diff --git a/dimos/core/docker_runner.py b/dimos/core/docker_runner.py index 99833a9b97..dcb75fbdee 100644 --- a/dimos/core/docker_runner.py +++ b/dimos/core/docker_runner.py @@ -28,7 +28,7 @@ from dimos.core.docker_build import build_image, image_exists from dimos.core.module import Module, ModuleConfig from dimos.core.rpc_client import RpcCall -from dimos.protocol.rpc import LCMRPC +from dimos.protocol.rpc.pubsubrpc import LCMRPC from dimos.utils.logging_config import setup_logger from dimos.visualization.rerun.bridge import RERUN_GRPC_PORT, RERUN_WEB_PORT diff --git a/dimos/core/introspection/__init__.py b/dimos/core/introspection/__init__.py deleted file mode 100644 index c40c3d49e6..0000000000 --- a/dimos/core/introspection/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Module and blueprint introspection utilities.""" - -from dimos.core.introspection.module import INTERNAL_RPCS, render_module_io -from dimos.core.introspection.svg import to_svg - -__all__ = ["INTERNAL_RPCS", "render_module_io", "to_svg"] diff --git a/dimos/core/introspection/blueprint/__init__.py b/dimos/core/introspection/blueprint/__init__.py deleted file mode 100644 index 6545b39dfa..0000000000 --- a/dimos/core/introspection/blueprint/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Blueprint introspection and rendering. - -Renderers: - - dot: Graphviz DOT format (hub-style with type nodes as intermediate hubs) -""" - -from dimos.core.introspection.blueprint import dot -from dimos.core.introspection.blueprint.dot import LayoutAlgo, render_svg - -__all__ = ["LayoutAlgo", "dot", "render_svg"] diff --git a/dimos/core/introspection/module/__init__.py b/dimos/core/introspection/module/__init__.py deleted file mode 100644 index 444d0e24f3..0000000000 --- a/dimos/core/introspection/module/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Module introspection and rendering. - -Renderers: - - ansi: ANSI terminal output (default) - - dot: Graphviz DOT format -""" - -from dimos.core.introspection.module import ansi, dot -from dimos.core.introspection.module.info import ( - INTERNAL_RPCS, - ModuleInfo, - ParamInfo, - RpcInfo, - SkillInfo, - StreamInfo, - extract_module_info, -) -from dimos.core.introspection.module.render import render_module_io - -__all__ = [ - "INTERNAL_RPCS", - "ModuleInfo", - "ParamInfo", - "RpcInfo", - "SkillInfo", - "StreamInfo", - "ansi", - "dot", - "extract_module_info", - "render_module_io", -] diff --git a/dimos/core/module.py b/dimos/core/module.py index ab21ce17a9..1c5b311883 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -34,19 +34,21 @@ from dimos.core.core import T, rpc from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.introspection.module import extract_module_info, render_module_io +from dimos.core.introspection.module.info import extract_module_info +from dimos.core.introspection.module.render import render_module_io from dimos.core.resource import Resource from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out, RemoteOut, Transport -from dimos.protocol.rpc import LCMRPC, RPCSpec -from dimos.protocol.service import BaseConfig, Configurable -from dimos.protocol.tf import LCMTF, TFSpec +from dimos.protocol.rpc.pubsubrpc import LCMRPC +from dimos.protocol.rpc.spec import RPCSpec +from dimos.protocol.service.spec import BaseConfig, Configurable +from dimos.protocol.tf.tf import LCMTF, TFSpec from dimos.utils import colors from dimos.utils.generic import classproperty if TYPE_CHECKING: from dimos.core.blueprints import Blueprint - from dimos.core.introspection.module import ModuleInfo + from dimos.core.introspection.module.info import ModuleInfo from dimos.core.rpc_client import RPCClient if sys.version_info >= (3, 13): diff --git a/dimos/core/resource_monitor/__init__.py b/dimos/core/resource_monitor/__init__.py deleted file mode 100644 index 217941a2ec..0000000000 --- a/dimos/core/resource_monitor/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from dimos.core.resource_monitor.logger import ( - LCMResourceLogger, - ResourceLogger, - StructlogResourceLogger, -) -from dimos.core.resource_monitor.monitor import StatsMonitor -from dimos.core.resource_monitor.stats import ProcessStats, WorkerStats, collect_process_stats - -__all__ = [ - "LCMResourceLogger", - "ProcessStats", - "ResourceLogger", - "StatsMonitor", - "StructlogResourceLogger", - "WorkerStats", - "collect_process_stats", -] diff --git a/dimos/core/resource_monitor/stats.py b/dimos/core/resource_monitor/stats.py index 485132db46..6264d5c7f9 100644 --- a/dimos/core/resource_monitor/stats.py +++ b/dimos/core/resource_monitor/stats.py @@ -19,7 +19,7 @@ import psutil -from dimos.utils.decorators import ttl_cache +from dimos.utils.decorators.decorators import ttl_cache # Cache Process objects so cpu_percent(interval=None) has a previous sample. _proc_cache: dict[int, psutil.Process] = {} diff --git a/dimos/core/rpc_client.py b/dimos/core/rpc_client.py index e46124469c..84de18d671 100644 --- a/dimos/core/rpc_client.py +++ b/dimos/core/rpc_client.py @@ -17,7 +17,8 @@ from dimos.core.stream import RemoteStream from dimos.core.worker import MethodCallProxy -from dimos.protocol.rpc import LCMRPC, RPCSpec +from dimos.protocol.rpc.pubsubrpc import LCMRPC +from dimos.protocol.rpc.spec import RPCSpec from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index 19dbf62c74..5f7bf33b8b 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -33,7 +33,7 @@ from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.spec.utils import Spec # Disable Rerun for tests (prevents viewer spawn and gRPC flush errors) diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 3bd1383761..f9a89829d5 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -22,8 +22,8 @@ from dimos.core.stream import In, Out from dimos.core.testing import MockRobotClient from dimos.core.transport import LCMTransport, pLCMTransport -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.odometry import Odometry diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index 16cb44b907..fdea17d2a3 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -24,7 +24,7 @@ from dimos.core.stream import In from dimos.core.testing import MockRobotClient from dimos.core.transport import LCMTransport, pLCMTransport -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.odometry import Odometry diff --git a/dimos/core/test_worker.py b/dimos/core/test_worker.py index 306b3fdb3d..021b2e21c4 100644 --- a/dimos/core/test_worker.py +++ b/dimos/core/test_worker.py @@ -21,7 +21,7 @@ from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.core.worker_manager import WorkerManager -from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 if TYPE_CHECKING: from dimos.core.resource_monitor.stats import WorkerStats diff --git a/dimos/core/testing.py b/dimos/core/testing.py index 3bb5865192..a128fc4767 100644 --- a/dimos/core/testing.py +++ b/dimos/core/testing.py @@ -19,11 +19,11 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry -from dimos.utils.testing import SensorReplay +from dimos.utils.testing.replay import SensorReplay class MockRobotClient(Module): diff --git a/dimos/e2e_tests/conftest.py b/dimos/e2e_tests/conftest.py index 51ab7c2c18..12f4a674a6 100644 --- a/dimos/e2e_tests/conftest.py +++ b/dimos/e2e_tests/conftest.py @@ -22,7 +22,8 @@ from dimos.e2e_tests.conf_types import StartPersonTrack from dimos.e2e_tests.dimos_cli_call import DimosCliCall from dimos.e2e_tests.lcm_spy import LcmSpy -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import make_vector3 from dimos.msgs.std_msgs.Bool import Bool from dimos.simulation.mujoco.person_on_track import PersonTrackPublisher diff --git a/dimos/e2e_tests/lcm_spy.py b/dimos/e2e_tests/lcm_spy.py index 9efed09d5e..030591f52e 100644 --- a/dimos/e2e_tests/lcm_spy.py +++ b/dimos/e2e_tests/lcm_spy.py @@ -22,8 +22,8 @@ import lcm -from dimos.msgs import DimosMsg -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.protocol import DimosMsg from dimos.protocol.service.lcmservice import LCMService diff --git a/dimos/e2e_tests/test_control_coordinator.py b/dimos/e2e_tests/test_control_coordinator.py index 5bb7a096f7..80b63c529f 100644 --- a/dimos/e2e_tests/test_control_coordinator.py +++ b/dimos/e2e_tests/test_control_coordinator.py @@ -24,8 +24,10 @@ from dimos.control.coordinator import ControlCoordinator from dimos.core.rpc_client import RPCClient -from dimos.msgs.sensor_msgs import JointState -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint, TrajectoryState +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryPoint import TrajectoryPoint +from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState @pytest.mark.skipif_in_ci diff --git a/dimos/e2e_tests/test_simulation_module.py b/dimos/e2e_tests/test_simulation_module.py index b5902ad7e2..e08183fc24 100644 --- a/dimos/e2e_tests/test_simulation_module.py +++ b/dimos/e2e_tests/test_simulation_module.py @@ -16,7 +16,9 @@ import pytest -from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState +from dimos.msgs.sensor_msgs.JointCommand import JointCommand +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.RobotState import RobotState def _positions_within_tolerance( diff --git a/dimos/exceptions/__init__.py b/dimos/exceptions/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/hardware/__init__.py b/dimos/hardware/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/hardware/drive_trains/__init__.py b/dimos/hardware/drive_trains/__init__.py deleted file mode 100644 index c6e843feea..0000000000 --- a/dimos/hardware/drive_trains/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Drive train hardware adapters for velocity-commanded platforms.""" diff --git a/dimos/hardware/drive_trains/flowbase/__init__.py b/dimos/hardware/drive_trains/flowbase/__init__.py deleted file mode 100644 index 25f95e399c..0000000000 --- a/dimos/hardware/drive_trains/flowbase/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""FlowBase twist base adapter for holonomic base control via Portal RPC.""" diff --git a/dimos/hardware/drive_trains/mock/__init__.py b/dimos/hardware/drive_trains/mock/__init__.py deleted file mode 100644 index 9b6f630040..0000000000 --- a/dimos/hardware/drive_trains/mock/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Mock twist base adapter for testing without hardware. - -Usage: - >>> from dimos.hardware.drive_trains.mock import MockTwistBaseAdapter - >>> adapter = MockTwistBaseAdapter(dof=3) - >>> adapter.connect() - True - >>> adapter.write_velocities([0.5, 0.0, 0.1]) - True - >>> adapter.read_velocities() - [0.5, 0.0, 0.1] -""" - -from dimos.hardware.drive_trains.mock.adapter import MockTwistBaseAdapter - -__all__ = ["MockTwistBaseAdapter"] diff --git a/dimos/hardware/end_effectors/__init__.py b/dimos/hardware/end_effectors/__init__.py deleted file mode 100644 index 9a7aa9759a..0000000000 --- a/dimos/hardware/end_effectors/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .end_effector import EndEffector - -__all__ = ["EndEffector"] diff --git a/dimos/hardware/manipulators/__init__.py b/dimos/hardware/manipulators/__init__.py deleted file mode 100644 index 58986c9211..0000000000 --- a/dimos/hardware/manipulators/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Manipulator drivers for robotic arms. - -Architecture: Protocol-based adapters for different manipulator hardware. -- spec.py: ManipulatorAdapter Protocol and shared types -- xarm/: XArm adapter -- piper/: Piper adapter -- mock/: Mock adapter for testing - -Usage: - >>> from dimos.hardware.manipulators.xarm import XArm - >>> arm = XArm(ip="192.168.1.185") - >>> arm.start() - >>> arm.enable_servos() - >>> arm.move_joint([0, 0, 0, 0, 0, 0]) - -Testing: - >>> from dimos.hardware.manipulators.xarm import XArm - >>> from dimos.hardware.manipulators.mock import MockAdapter - >>> arm = XArm(adapter=MockAdapter()) - >>> arm.start() # No hardware needed! -""" - -from dimos.hardware.manipulators.spec import ( - ControlMode, - DriverStatus, - JointLimits, - ManipulatorAdapter, - ManipulatorInfo, -) - -__all__ = [ - "ControlMode", - "DriverStatus", - "JointLimits", - "ManipulatorAdapter", - "ManipulatorInfo", -] diff --git a/dimos/hardware/manipulators/mock/__init__.py b/dimos/hardware/manipulators/mock/__init__.py deleted file mode 100644 index 63be6f7e98..0000000000 --- a/dimos/hardware/manipulators/mock/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Mock adapter for testing manipulator drivers without hardware. - -Usage: - >>> from dimos.hardware.manipulators.xarm import XArm - >>> from dimos.hardware.manipulators.mock import MockAdapter - >>> arm = XArm(adapter=MockAdapter()) - >>> arm.start() # No hardware needed! - >>> arm.move_joint([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]) - >>> assert arm.adapter.read_joint_positions() == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] -""" - -from dimos.hardware.manipulators.mock.adapter import MockAdapter - -__all__ = ["MockAdapter"] diff --git a/dimos/hardware/manipulators/piper/__init__.py b/dimos/hardware/manipulators/piper/__init__.py deleted file mode 100644 index bfeb89b1c0..0000000000 --- a/dimos/hardware/manipulators/piper/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Piper manipulator hardware adapter. - -Usage: - >>> from dimos.hardware.manipulators.piper import PiperAdapter - >>> adapter = PiperAdapter(can_port="can0") - >>> adapter.connect() - >>> positions = adapter.read_joint_positions() -""" - -from dimos.hardware.manipulators.piper.adapter import PiperAdapter - -__all__ = ["PiperAdapter"] diff --git a/dimos/hardware/manipulators/registry.py b/dimos/hardware/manipulators/registry.py index 65dbe74b50..9e63fa349b 100644 --- a/dimos/hardware/manipulators/registry.py +++ b/dimos/hardware/manipulators/registry.py @@ -33,7 +33,6 @@ import importlib import logging -import pkgutil from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -78,19 +77,25 @@ def available(self) -> list[str]: def discover(self) -> None: """Discover and register adapters from subpackages. + Scans for subdirectories containing an adapter.py module. Can be called multiple times to pick up newly added adapters. """ - import dimos.hardware.manipulators as pkg + from pathlib import Path - for _, name, ispkg in pkgutil.iter_modules(pkg.__path__): - if not ispkg: + pkg_dir = Path(__file__).parent + for child in sorted(pkg_dir.iterdir()): + if not child.is_dir() or child.name.startswith(("_", ".")): + continue + if not (child / "adapter.py").exists(): continue try: - module = importlib.import_module(f"dimos.hardware.manipulators.{name}.adapter") + module = importlib.import_module( + f"dimos.hardware.manipulators.{child.name}.adapter" + ) if hasattr(module, "register"): module.register(self) except ImportError as e: - logger.debug(f"Skipping adapter {name}: {e}") + logger.debug(f"Skipping adapter {child.name}: {e}") adapter_registry = AdapterRegistry() diff --git a/dimos/hardware/manipulators/spec.py b/dimos/hardware/manipulators/spec.py index ed63a21e82..868b714bfa 100644 --- a/dimos/hardware/manipulators/spec.py +++ b/dimos/hardware/manipulators/spec.py @@ -26,7 +26,9 @@ from enum import Enum from typing import Protocol, runtime_checkable -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 class DriverStatus(Enum): diff --git a/dimos/hardware/manipulators/xarm/__init__.py b/dimos/hardware/manipulators/xarm/__init__.py deleted file mode 100644 index 8bcab667c1..0000000000 --- a/dimos/hardware/manipulators/xarm/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""XArm manipulator hardware adapter. - -Usage: - >>> from dimos.hardware.manipulators.xarm import XArmAdapter - >>> adapter = XArmAdapter(ip="192.168.1.185", dof=6) - >>> adapter.connect() - >>> positions = adapter.read_joint_positions() -""" - -from dimos.hardware.manipulators.xarm.adapter import XArmAdapter - -__all__ = ["XArmAdapter"] diff --git a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py index ec19d6844e..c723cab130 100644 --- a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py +++ b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py @@ -25,7 +25,7 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.utils.logging_config import setup_logger # Add system path for gi module if needed diff --git a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera_test_script.py b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera_test_script.py index 8785a9260b..a18d52fbb0 100755 --- a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera_test_script.py +++ b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera_test_script.py @@ -21,8 +21,8 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.transport import LCMTransport from dimos.hardware.sensors.camera.gstreamer.gstreamer_camera import GstreamerCameraModule -from dimos.msgs.sensor_msgs import Image -from dimos.protocol import pubsub +from dimos.msgs.sensor_msgs.Image import Image +from dimos.protocol.pubsub.impl import lcmpubsub as _lcm logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -59,7 +59,7 @@ def main() -> None: logging.getLogger().setLevel(logging.DEBUG) # Initialize LCM - pubsub.lcm.autoconf() # type: ignore[attr-defined] + _lcm.autoconf() # type: ignore[attr-defined] # Start dimos dimos = ModuleCoordinator() diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index 0f055f0352..e0d0b3407e 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -26,7 +26,9 @@ from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware from dimos.hardware.sensors.camera.webcam import Webcam -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.spec import perception diff --git a/dimos/hardware/sensors/camera/realsense/__init__.py b/dimos/hardware/sensors/camera/realsense/__init__.py deleted file mode 100644 index 58f519a12e..0000000000 --- a/dimos/hardware/sensors/camera/realsense/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from dimos.hardware.sensors.camera.realsense.camera import ( - RealSenseCamera, - RealSenseCameraConfig, - realsense_camera, - ) - -__all__ = ["RealSenseCamera", "RealSenseCameraConfig", "realsense_camera"] - - -def __getattr__(name: str) -> object: - if name in __all__: - from dimos.hardware.sensors.camera.realsense.camera import ( - RealSenseCamera, - RealSenseCameraConfig, - realsense_camera, - ) - - globals().update( - RealSenseCamera=RealSenseCamera, - RealSenseCameraConfig=RealSenseCameraConfig, - realsense_camera=realsense_camera, - ) - return globals()[name] - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/dimos/hardware/sensors/camera/realsense/camera.py b/dimos/hardware/sensors/camera/realsense/camera.py index 5908525826..23bc19cdad 100644 --- a/dimos/hardware/sensors/camera/realsense/camera.py +++ b/dimos/hardware/sensors/camera/realsense/camera.py @@ -35,8 +35,10 @@ DepthCameraConfig, DepthCameraHardware, ) -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.foxglove_bridge import FoxgloveBridge diff --git a/dimos/hardware/sensors/camera/spec.py b/dimos/hardware/sensors/camera/spec.py index be37ec734a..dcb0196ff2 100644 --- a/dimos/hardware/sensors/camera/spec.py +++ b/dimos/hardware/sensors/camera/spec.py @@ -17,8 +17,9 @@ from reactivex.observable import Observable -from dimos.msgs.geometry_msgs import Quaternion, Transform -from dimos.msgs.sensor_msgs import CameraInfo +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image from dimos.protocol.service.spec import BaseConfig, Configurable diff --git a/dimos/hardware/sensors/camera/webcam.py b/dimos/hardware/sensors/camera/webcam.py index 51199624fe..cfd1a080a0 100644 --- a/dimos/hardware/sensors/camera/webcam.py +++ b/dimos/hardware/sensors/camera/webcam.py @@ -23,8 +23,8 @@ from reactivex.observable import Observable from dimos.hardware.sensors.camera.spec import CameraConfig, CameraHardware -from dimos.msgs.sensor_msgs import CameraInfo, Image -from dimos.msgs.sensor_msgs.Image import ImageFormat +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.utils.reactive import backpressure diff --git a/dimos/hardware/sensors/camera/zed/camera.py b/dimos/hardware/sensors/camera/zed/camera.py index 2df9afd70c..214b1f73e3 100644 --- a/dimos/hardware/sensors/camera/zed/camera.py +++ b/dimos/hardware/sensors/camera/zed/camera.py @@ -33,8 +33,10 @@ DepthCameraConfig, DepthCameraHardware, ) -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.foxglove_bridge import FoxgloveBridge diff --git a/dimos/hardware/sensors/camera/zed/__init__.py b/dimos/hardware/sensors/camera/zed/compat.py similarity index 97% rename from dimos/hardware/sensors/camera/zed/__init__.py rename to dimos/hardware/sensors/camera/zed/compat.py index 6e3b905e90..3cec8d9566 100644 --- a/dimos/hardware/sensors/camera/zed/__init__.py +++ b/dimos/hardware/sensors/camera/zed/compat.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ZED camera hardware interfaces.""" +"""ZED camera compatibility layer and SDK detection.""" from pathlib import Path diff --git a/dimos/hardware/sensors/camera/zed/test_zed.py b/dimos/hardware/sensors/camera/zed/test_zed.py index 2716e809a5..a98055a355 100644 --- a/dimos/hardware/sensors/camera/zed/test_zed.py +++ b/dimos/hardware/sensors/camera/zed/test_zed.py @@ -15,7 +15,7 @@ import pytest -from dimos.hardware.sensors.camera import zed +from dimos.hardware.sensors.camera.zed import compat as zed from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo diff --git a/dimos/hardware/sensors/fake_zed_module.py b/dimos/hardware/sensors/fake_zed_module.py index ca5014337b..16e85aa93c 100644 --- a/dimos/hardware/sensors/fake_zed_module.py +++ b/dimos/hardware/sensors/fake_zed_module.py @@ -27,12 +27,12 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.msgs.std_msgs import Header -from dimos.protocol.tf import TF +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.std_msgs.Header import Header +from dimos.protocol.tf.tf import TF from dimos.utils.logging_config import setup_logger -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay logger = setup_logger(level=logging.INFO) @@ -278,7 +278,9 @@ def _publish_pose(self, msg) -> None: # type: ignore[no-untyped-def] # Publish TF transform from world to camera import time - from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + from dimos.msgs.geometry_msgs.Transform import Transform + from dimos.msgs.geometry_msgs.Vector3 import Vector3 transform = Transform( translation=Vector3(*msg.position), diff --git a/dimos/hardware/sensors/lidar/__init__.py b/dimos/hardware/sensors/lidar/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/hardware/sensors/lidar/fastlio2/__init__.py b/dimos/hardware/sensors/lidar/fastlio2/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/hardware/sensors/lidar/livox/__init__.py b/dimos/hardware/sensors/lidar/livox/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/manipulation/__init__.py b/dimos/manipulation/__init__.py deleted file mode 100644 index d2a511d146..0000000000 --- a/dimos/manipulation/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Manipulation module for robot arm motion planning and control.""" - -from dimos.manipulation.manipulation_module import ( - ManipulationModule, - ManipulationModuleConfig, - ManipulationState, - manipulation_module, -) -from dimos.manipulation.pick_and_place_module import ( - PickAndPlaceModule, - PickAndPlaceModuleConfig, - pick_and_place_module, -) - -__all__ = [ - "ManipulationModule", - "ManipulationModuleConfig", - "ManipulationState", - "PickAndPlaceModule", - "PickAndPlaceModuleConfig", - "manipulation_module", - "pick_and_place_module", -] diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index 7a0eefb37a..8ef2c03279 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -35,12 +35,15 @@ from dimos.control.coordinator import TaskConfig, control_coordinator from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport -from dimos.hardware.sensors.camera.realsense import realsense_camera +from dimos.hardware.sensors.camera.realsense.camera import realsense_camera from dimos.manipulation.manipulation_module import manipulation_module from dimos.manipulation.pick_and_place_module import pick_and_place_module -from dimos.manipulation.planning.spec import RobotModelConfig -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import JointState +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.perception.object_scene_registration import object_scene_registration_module from dimos.robot.foxglove_bridge import foxglove_bridge # TODO: migrate to rerun from dimos.utils.data import get_data diff --git a/dimos/manipulation/control/__init__.py b/dimos/manipulation/control/__init__.py deleted file mode 100644 index ec85660eb3..0000000000 --- a/dimos/manipulation/control/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Manipulation Control Modules - -Hardware-agnostic controllers for robotic manipulation tasks. - -Submodules: -- servo_control: Real-time servo-level controllers (Cartesian motion control) -- trajectory_controller: Trajectory planning and execution -""" - -# Re-export from servo_control for backwards compatibility -from dimos.manipulation.control.servo_control import ( - CartesianMotionController, - CartesianMotionControllerConfig, - cartesian_motion_controller, -) - -# Re-export from trajectory_controller -from dimos.manipulation.control.trajectory_controller import ( - JointTrajectoryController, - JointTrajectoryControllerConfig, - joint_trajectory_controller, -) - -__all__ = [ - # Servo control - "CartesianMotionController", - "CartesianMotionControllerConfig", - # Trajectory control - "JointTrajectoryController", - "JointTrajectoryControllerConfig", - "cartesian_motion_controller", - "joint_trajectory_controller", -] diff --git a/dimos/manipulation/control/coordinator_client.py b/dimos/manipulation/control/coordinator_client.py index cbaad28df2..dfa99371a6 100644 --- a/dimos/manipulation/control/coordinator_client.py +++ b/dimos/manipulation/control/coordinator_client.py @@ -54,7 +54,7 @@ ) if TYPE_CHECKING: - from dimos.msgs.trajectory_msgs import JointTrajectory + from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory class CoordinatorClient: diff --git a/dimos/manipulation/control/dual_trajectory_setter.py b/dimos/manipulation/control/dual_trajectory_setter.py index 05793eeb76..3fdccea400 100644 --- a/dimos/manipulation/control/dual_trajectory_setter.py +++ b/dimos/manipulation/control/dual_trajectory_setter.py @@ -37,8 +37,8 @@ from dimos.manipulation.planning.trajectory_generator.joint_trajectory_generator import ( JointTrajectoryGenerator, ) -from dimos.msgs.sensor_msgs import JointState -from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory @dataclass diff --git a/dimos/manipulation/control/servo_control/__init__.py b/dimos/manipulation/control/servo_control/__init__.py deleted file mode 100644 index 5418a7e24b..0000000000 --- a/dimos/manipulation/control/servo_control/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Servo Control Modules - -Real-time servo-level controllers for robotic manipulation. -Includes Cartesian motion control with PID-based tracking. -""" - -from dimos.manipulation.control.servo_control.cartesian_motion_controller import ( - CartesianMotionController, - CartesianMotionControllerConfig, - cartesian_motion_controller, -) - -__all__ = [ - "CartesianMotionController", - "CartesianMotionControllerConfig", - "cartesian_motion_controller", -] diff --git a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py index a12fb44a96..0cbd41e218 100644 --- a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py +++ b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py @@ -34,8 +34,14 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Twist, Vector3 -from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointCommand import JointCommand +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.RobotState import RobotState from dimos.utils.logging_config import setup_logger from dimos.utils.simple_controller import PIDController diff --git a/dimos/manipulation/control/target_setter.py b/dimos/manipulation/control/target_setter.py index f54a6af2f0..a0228c6a24 100644 --- a/dimos/manipulation/control/target_setter.py +++ b/dimos/manipulation/control/target_setter.py @@ -25,7 +25,9 @@ import time from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 class TargetSetter: diff --git a/dimos/manipulation/control/trajectory_controller/__init__.py b/dimos/manipulation/control/trajectory_controller/__init__.py deleted file mode 100644 index fb4360d4cc..0000000000 --- a/dimos/manipulation/control/trajectory_controller/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Trajectory Controller Module - -Joint-space trajectory execution for robotic manipulators. -""" - -from dimos.manipulation.control.trajectory_controller.joint_trajectory_controller import ( - JointTrajectoryController, - JointTrajectoryControllerConfig, - joint_trajectory_controller, -) - -__all__ = [ - "JointTrajectoryController", - "JointTrajectoryControllerConfig", - "joint_trajectory_controller", -] diff --git a/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py b/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py index ed62a7345e..465df7afea 100644 --- a/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py +++ b/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py @@ -36,8 +36,11 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryState, TrajectoryStatus +from dimos.msgs.sensor_msgs.JointCommand import JointCommand +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.RobotState import RobotState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState, TrajectoryStatus from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/manipulation/control/trajectory_controller/spec.py b/dimos/manipulation/control/trajectory_controller/spec.py index e11da91847..b696f2dc6a 100644 --- a/dimos/manipulation/control/trajectory_controller/spec.py +++ b/dimos/manipulation/control/trajectory_controller/spec.py @@ -30,8 +30,11 @@ if TYPE_CHECKING: from dimos.core.stream import In, Out - from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState - from dimos.msgs.trajectory_msgs import JointTrajectory as JointTrajectoryMsg, TrajectoryState + from dimos.msgs.sensor_msgs.JointCommand import JointCommand + from dimos.msgs.sensor_msgs.JointState import JointState + from dimos.msgs.sensor_msgs.RobotState import RobotState + from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory as JointTrajectoryMsg + from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState # Input topics joint_state: In[JointState] | None = None # Feedback from arm driver diff --git a/dimos/manipulation/control/trajectory_setter.py b/dimos/manipulation/control/trajectory_setter.py index a5baa512b5..25f9db2a3f 100644 --- a/dimos/manipulation/control/trajectory_setter.py +++ b/dimos/manipulation/control/trajectory_setter.py @@ -36,8 +36,8 @@ from dimos.manipulation.planning.trajectory_generator.joint_trajectory_generator import ( JointTrajectoryGenerator, ) -from dimos.msgs.sensor_msgs import JointState -from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory class TrajectorySetter: diff --git a/dimos/manipulation/grasping/__init__.py b/dimos/manipulation/grasping/__init__.py deleted file mode 100644 index 41779f55e7..0000000000 --- a/dimos/manipulation/grasping/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dimos.manipulation.grasping.graspgen_module import ( - GraspGenConfig, - GraspGenModule, - graspgen, -) -from dimos.manipulation.grasping.grasping import ( - GraspingModule, - grasping_module, -) - -__all__ = [ - "GraspGenConfig", - "GraspGenModule", - "GraspingModule", - "graspgen", - "grasping_module", -] diff --git a/dimos/manipulation/grasping/demo_grasping.py b/dimos/manipulation/grasping/demo_grasping.py index 01e34f905f..43a6c9a20a 100644 --- a/dimos/manipulation/grasping/demo_grasping.py +++ b/dimos/manipulation/grasping/demo_grasping.py @@ -16,8 +16,8 @@ from dimos.agents.agent import agent from dimos.core.blueprints import autoconnect -from dimos.hardware.sensors.camera.realsense import realsense_camera -from dimos.manipulation.grasping import graspgen +from dimos.hardware.sensors.camera.realsense.camera import realsense_camera +from dimos.manipulation.grasping.graspgen_module import graspgen from dimos.manipulation.grasping.grasping import grasping_module from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import object_scene_registration_module diff --git a/dimos/manipulation/grasping/graspgen_module.py b/dimos/manipulation/grasping/graspgen_module.py index 7ec8cfeeaa..c883126840 100644 --- a/dimos/manipulation/grasping/graspgen_module.py +++ b/dimos/manipulation/grasping/graspgen_module.py @@ -25,13 +25,13 @@ from dimos.core.docker_runner import DockerModuleConfig from dimos.core.module import Module from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseArray -from dimos.msgs.std_msgs import Header +from dimos.msgs.geometry_msgs.PoseArray import PoseArray +from dimos.msgs.std_msgs.Header import Header from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import matrix_to_pose if TYPE_CHECKING: - from dimos.msgs.sensor_msgs import PointCloud2 + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 logger = setup_logger() diff --git a/dimos/manipulation/grasping/grasping.py b/dimos/manipulation/grasping/grasping.py index 433a07d846..ef05dc29e2 100644 --- a/dimos/manipulation/grasping/grasping.py +++ b/dimos/manipulation/grasping/grasping.py @@ -25,12 +25,12 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseArray +from dimos.msgs.geometry_msgs.PoseArray import PoseArray from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import quaternion_to_euler if TYPE_CHECKING: - from dimos.msgs.sensor_msgs import PointCloud2 + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 logger = setup_logger() diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index f064130965..fe5561c705 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -34,23 +34,20 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In -from dimos.manipulation.planning import ( - JointPath, +from dimos.manipulation.planning.factory import create_kinematics, create_planner +from dimos.manipulation.planning.monitor.world_monitor import WorldMonitor +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.manipulation.planning.spec.enums import ObstacleType +from dimos.manipulation.planning.spec.models import JointPath, Obstacle, RobotName, WorldRobotID +from dimos.manipulation.planning.spec.protocols import KinematicsSpec, PlannerSpec +from dimos.manipulation.planning.trajectory_generator.joint_trajectory_generator import ( JointTrajectoryGenerator, - KinematicsSpec, - Obstacle, - ObstacleType, - PlannerSpec, - RobotModelConfig, - RobotName, - WorldRobotID, - create_kinematics, - create_planner, ) -from dimos.manipulation.planning.monitor import WorldMonitor -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 -from dimos.msgs.sensor_msgs import JointState -from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -247,7 +244,7 @@ def _on_joint_state(self, msg: JointState) -> None: def _tf_publish_loop(self) -> None: """Publish TF transforms at 10Hz for EE and extra links.""" - from dimos.msgs.geometry_msgs import Transform + from dimos.msgs.geometry_msgs.Transform import Transform period = 0.1 # 10Hz while not self._tf_stop_event.is_set(): @@ -406,7 +403,7 @@ def plan_to_pose(self, pose: Pose, robot_name: RobotName | None = None) -> bool: return self._fail("No joint state") # Convert Pose to PoseStamped for the IK solver - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped target_pose = PoseStamped( frame_id="world", @@ -750,7 +747,7 @@ def add_obstacle( return "" # Import PoseStamped here to avoid circular imports - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped obstacle = Obstacle( name=name, diff --git a/dimos/manipulation/pick_and_place_module.py b/dimos/manipulation/pick_and_place_module.py index 6d6ad1042e..b433df6801 100644 --- a/dimos/manipulation/pick_and_place_module.py +++ b/dimos/manipulation/pick_and_place_module.py @@ -37,7 +37,9 @@ ManipulationModule, ManipulationModuleConfig, ) -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.perception.detection.type.detection3d.object import ( Object as DetObject, ) @@ -45,8 +47,8 @@ from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import PoseArray - from dimos.msgs.sensor_msgs import PointCloud2 + from dimos.msgs.geometry_msgs.PoseArray import PoseArray + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 logger = setup_logger() diff --git a/dimos/manipulation/planning/__init__.py b/dimos/manipulation/planning/__init__.py deleted file mode 100644 index 8aaf0caa25..0000000000 --- a/dimos/manipulation/planning/__init__.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Manipulation Planning Module - -Motion planning stack for robotic manipulators using Protocol-based architecture. - -## Architecture - -- WorldSpec: Core backend owning physics/collision (DrakeWorld, future: MuJoCoWorld) -- KinematicsSpec: IK solvers - - JacobianIK: Backend-agnostic iterative/differential IK - - DrakeOptimizationIK: Drake-specific nonlinear optimization IK -- PlannerSpec: Backend-agnostic joint-space path planning - - RRTConnectPlanner: Bi-directional RRT-Connect - - RRTStarPlanner: RRT* (asymptotically optimal) - -## Factory Functions - -Use factory functions to create components: - -```python -from dimos.manipulation.planning.factory import ( - create_world, - create_kinematics, - create_planner, -) - -world = create_world(backend="drake", enable_viz=True) -kinematics = create_kinematics(name="jacobian") # or "drake_optimization" -planner = create_planner(name="rrt_connect") # backend-agnostic -``` - -## Monitors - -Use WorldMonitor for reactive state synchronization: - -```python -from dimos.manipulation.planning.monitor import WorldMonitor - -monitor = WorldMonitor(enable_viz=True) -robot_id = monitor.add_robot(config) -monitor.finalize() -monitor.start_state_monitor(robot_id) -``` -""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "factory": ["create_kinematics", "create_planner", "create_planning_stack", "create_world"], - "spec": [ - "CollisionObjectMessage", - "IKResult", - "IKStatus", - "JointPath", - "KinematicsSpec", - "Obstacle", - "ObstacleType", - "PlannerSpec", - "PlanningResult", - "PlanningStatus", - "RobotModelConfig", - "RobotName", - "WorldRobotID", - "WorldSpec", - ], - "trajectory_generator.joint_trajectory_generator": ["JointTrajectoryGenerator"], - }, -) diff --git a/dimos/manipulation/planning/examples/__init__.py b/dimos/manipulation/planning/examples/__init__.py deleted file mode 100644 index 7971835dab..0000000000 --- a/dimos/manipulation/planning/examples/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Manipulation planning examples. -""" diff --git a/dimos/manipulation/planning/examples/manipulation_client.py b/dimos/manipulation/planning/examples/manipulation_client.py index ac098ac52a..4dcd2fe9e8 100644 --- a/dimos/manipulation/planning/examples/manipulation_client.py +++ b/dimos/manipulation/planning/examples/manipulation_client.py @@ -49,7 +49,9 @@ from dimos.core.rpc_client import RPCClient from dimos.manipulation.manipulation_module import ManipulationModule -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 _client = RPCClient(None, ManipulationModule) @@ -71,7 +73,7 @@ def state() -> str: def plan(target_joints: list[float], robot_name: str | None = None) -> bool: """Plan to joint configuration. e.g. plan([0.1]*7)""" - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.sensor_msgs.JointState import JointState js = JointState(position=target_joints) return _client.plan_to_joints(js, robot_name) @@ -106,7 +108,7 @@ def execute(robot_name: str | None = None) -> bool: def home(robot_name: str | None = None) -> bool: """Plan and execute move to home position.""" - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.sensor_msgs.JointState import JointState home_joints = _client.get_robot_info(robot_name).get("home_joints", [0.0] * 7) success = _client.plan_to_joints(JointState(position=home_joints), robot_name) diff --git a/dimos/manipulation/planning/factory.py b/dimos/manipulation/planning/factory.py index d392bac563..65173dfd18 100644 --- a/dimos/manipulation/planning/factory.py +++ b/dimos/manipulation/planning/factory.py @@ -19,11 +19,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from dimos.manipulation.planning.spec import ( - KinematicsSpec, - PlannerSpec, - WorldSpec, - ) + from dimos.manipulation.planning.spec.protocols import KinematicsSpec, PlannerSpec, WorldSpec def create_world( diff --git a/dimos/manipulation/planning/kinematics/__init__.py b/dimos/manipulation/planning/kinematics/__init__.py deleted file mode 100644 index dacd2007cb..0000000000 --- a/dimos/manipulation/planning/kinematics/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Kinematics Module - -Contains IK solver implementations that use WorldSpec. - -## Implementations - -- JacobianIK: Backend-agnostic iterative/differential IK (works with any WorldSpec) -- DrakeOptimizationIK: Drake-specific nonlinear optimization IK (requires DrakeWorld) - -## Usage - -Use factory functions to create IK solvers: - -```python -from dimos.manipulation.planning.factory import create_kinematics - -# Backend-agnostic (works with any WorldSpec) -kinematics = create_kinematics(name="jacobian") - -# Drake-specific (requires DrakeWorld, more accurate) -kinematics = create_kinematics(name="drake_optimization") - -result = kinematics.solve(world, robot_id, target_pose) -``` -""" - -from dimos.manipulation.planning.kinematics.drake_optimization_ik import ( - DrakeOptimizationIK, -) -from dimos.manipulation.planning.kinematics.jacobian_ik import JacobianIK -from dimos.manipulation.planning.kinematics.pinocchio_ik import ( - PinocchioIK, - PinocchioIKConfig, -) - -__all__ = ["DrakeOptimizationIK", "JacobianIK", "PinocchioIK", "PinocchioIKConfig"] diff --git a/dimos/manipulation/planning/kinematics/drake_optimization_ik.py b/dimos/manipulation/planning/kinematics/drake_optimization_ik.py index 1e6b1962a5..b13aa8947a 100644 --- a/dimos/manipulation/planning/kinematics/drake_optimization_ik.py +++ b/dimos/manipulation/planning/kinematics/drake_optimization_ik.py @@ -20,10 +20,13 @@ import numpy as np -from dimos.manipulation.planning.spec import IKResult, IKStatus, WorldRobotID, WorldSpec +from dimos.manipulation.planning.spec.enums import IKStatus +from dimos.manipulation.planning.spec.models import IKResult, WorldRobotID +from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.kinematics_utils import compute_pose_error -from dimos.msgs.geometry_msgs import PoseStamped, Transform -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import pose_to_matrix diff --git a/dimos/manipulation/planning/kinematics/jacobian_ik.py b/dimos/manipulation/planning/kinematics/jacobian_ik.py index c756045d36..fb493e2d5f 100644 --- a/dimos/manipulation/planning/kinematics/jacobian_ik.py +++ b/dimos/manipulation/planning/kinematics/jacobian_ik.py @@ -28,7 +28,9 @@ import numpy as np -from dimos.manipulation.planning.spec import IKResult, IKStatus, WorldRobotID, WorldSpec +from dimos.manipulation.planning.spec.enums import IKStatus +from dimos.manipulation.planning.spec.models import IKResult, WorldRobotID +from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.kinematics_utils import ( check_singularity, compute_error_twist, @@ -41,8 +43,11 @@ if TYPE_CHECKING: from numpy.typing import NDArray -from dimos.msgs.geometry_msgs import PoseStamped, Transform, Twist, Vector3 -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState logger = setup_logger() diff --git a/dimos/manipulation/planning/kinematics/pinocchio_ik.py b/dimos/manipulation/planning/kinematics/pinocchio_ik.py index ff1c2dcc2a..cb6ee91608 100644 --- a/dimos/manipulation/planning/kinematics/pinocchio_ik.py +++ b/dimos/manipulation/planning/kinematics/pinocchio_ik.py @@ -44,7 +44,8 @@ if TYPE_CHECKING: from numpy.typing import NDArray - from dimos.msgs.geometry_msgs import Pose, PoseStamped + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped logger = setup_logger() diff --git a/dimos/manipulation/planning/monitor/__init__.py b/dimos/manipulation/planning/monitor/__init__.py deleted file mode 100644 index c280bd4d56..0000000000 --- a/dimos/manipulation/planning/monitor/__init__.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -World Monitor Module - -Provides reactive monitoring for keeping WorldSpec synchronized with the real world. - -## Components - -- WorldMonitor: Top-level monitor using WorldSpec Protocol -- WorldStateMonitor: Syncs joint state to WorldSpec -- WorldObstacleMonitor: Syncs obstacles to WorldSpec - -All monitors use the factory pattern and Protocol types. - -## Example - -```python -from dimos.manipulation.planning.monitor import WorldMonitor - -monitor = WorldMonitor(enable_viz=True) -robot_id = monitor.add_robot(config) -monitor.finalize() - -# Start monitoring -monitor.start_state_monitor(robot_id) -monitor.start_obstacle_monitor() - -# Handle joint state messages -monitor.on_joint_state(msg, robot_id) - -# Thread-safe collision checking -is_valid = monitor.is_state_valid(robot_id, q_test) -``` -""" - -from dimos.manipulation.planning.monitor.world_monitor import WorldMonitor -from dimos.manipulation.planning.monitor.world_obstacle_monitor import ( - WorldObstacleMonitor, -) -from dimos.manipulation.planning.monitor.world_state_monitor import WorldStateMonitor - -# Re-export message types from spec for convenience -from dimos.manipulation.planning.spec import CollisionObjectMessage - -__all__ = [ - "CollisionObjectMessage", - "WorldMonitor", - "WorldObstacleMonitor", - "WorldStateMonitor", -] diff --git a/dimos/manipulation/planning/monitor/world_monitor.py b/dimos/manipulation/planning/monitor/world_monitor.py index cca2dda013..32f519dfd4 100644 --- a/dimos/manipulation/planning/monitor/world_monitor.py +++ b/dimos/manipulation/planning/monitor/world_monitor.py @@ -23,8 +23,8 @@ from dimos.manipulation.planning.factory import create_world from dimos.manipulation.planning.monitor.world_obstacle_monitor import WorldObstacleMonitor from dimos.manipulation.planning.monitor.world_state_monitor import WorldStateMonitor -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -33,15 +33,15 @@ import numpy as np from numpy.typing import NDArray - from dimos.manipulation.planning.spec import ( + from dimos.manipulation.planning.spec.config import RobotModelConfig + from dimos.manipulation.planning.spec.models import ( CollisionObjectMessage, JointPath, Obstacle, - RobotModelConfig, WorldRobotID, - WorldSpec, ) - from dimos.msgs.vision_msgs import Detection3D + from dimos.manipulation.planning.spec.protocols import WorldSpec + from dimos.msgs.vision_msgs.Detection3D import Detection3D from dimos.perception.detection.type.detection3d.object import Object logger = setup_logger() @@ -366,7 +366,7 @@ def get_link_pose( link_name: Name of the link in the URDF joint_state: Joint state to use (uses current if None) """ - from dimos.msgs.geometry_msgs import Quaternion + from dimos.msgs.geometry_msgs.Quaternion import Quaternion with self._world.scratch_context() as ctx: if joint_state is None: diff --git a/dimos/manipulation/planning/monitor/world_obstacle_monitor.py b/dimos/manipulation/planning/monitor/world_obstacle_monitor.py index 4f69afad68..a21ee68726 100644 --- a/dimos/manipulation/planning/monitor/world_obstacle_monitor.py +++ b/dimos/manipulation/planning/monitor/world_obstacle_monitor.py @@ -29,20 +29,17 @@ import time from typing import TYPE_CHECKING, Any -from dimos.manipulation.planning.spec import ( - CollisionObjectMessage, - Obstacle, - ObstacleType, -) -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.manipulation.planning.spec.enums import ObstacleType +from dimos.manipulation.planning.spec.models import CollisionObjectMessage, Obstacle +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: from collections.abc import Callable import threading - from dimos.manipulation.planning.spec import WorldSpec - from dimos.msgs.vision_msgs import Detection3D + from dimos.manipulation.planning.spec.protocols import WorldSpec + from dimos.msgs.vision_msgs.Detection3D import Detection3D from dimos.perception.detection.type.detection3d.object import Object logger = setup_logger() diff --git a/dimos/manipulation/planning/monitor/world_state_monitor.py b/dimos/manipulation/planning/monitor/world_state_monitor.py index 87d61bb66f..8548251c73 100644 --- a/dimos/manipulation/planning/monitor/world_state_monitor.py +++ b/dimos/manipulation/planning/monitor/world_state_monitor.py @@ -31,7 +31,7 @@ import numpy as np -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -40,7 +40,7 @@ from numpy.typing import NDArray - from dimos.manipulation.planning.spec import WorldSpec + from dimos.manipulation.planning.spec.protocols import WorldSpec logger = setup_logger() diff --git a/dimos/manipulation/planning/planners/__init__.py b/dimos/manipulation/planning/planners/__init__.py deleted file mode 100644 index 8fb8ae042b..0000000000 --- a/dimos/manipulation/planning/planners/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Motion Planners Module - -Contains motion planning implementations that use WorldSpec. - -All planners are backend-agnostic - they only use WorldSpec methods and -work with any physics backend (Drake, MuJoCo, PyBullet, etc.). - -## Implementations - -- RRTConnectPlanner: Bi-directional RRT-Connect planner (fast, reliable) - -## Usage - -Use factory functions to create planners: - -```python -from dimos.manipulation.planning.factory import create_planner - -planner = create_planner(name="rrt_connect") # Returns PlannerSpec -result = planner.plan_joint_path(world, robot_id, q_start, q_goal) -``` -""" - -from dimos.manipulation.planning.planners.rrt_planner import RRTConnectPlanner - -__all__ = ["RRTConnectPlanner"] diff --git a/dimos/manipulation/planning/planners/rrt_planner.py b/dimos/manipulation/planning/planners/rrt_planner.py index 71204488c4..7f308dce0c 100644 --- a/dimos/manipulation/planning/planners/rrt_planner.py +++ b/dimos/manipulation/planning/planners/rrt_planner.py @@ -26,15 +26,11 @@ import numpy as np -from dimos.manipulation.planning.spec import ( - JointPath, - PlanningResult, - PlanningStatus, - WorldRobotID, - WorldSpec, -) +from dimos.manipulation.planning.spec.enums import PlanningStatus +from dimos.manipulation.planning.spec.models import JointPath, PlanningResult, WorldRobotID +from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.path_utils import compute_path_length -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: diff --git a/dimos/manipulation/planning/spec/__init__.py b/dimos/manipulation/planning/spec/__init__.py deleted file mode 100644 index a78fb6e5fd..0000000000 --- a/dimos/manipulation/planning/spec/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Manipulation Planning Specifications.""" - -from dimos.manipulation.planning.spec.config import RobotModelConfig -from dimos.manipulation.planning.spec.enums import IKStatus, ObstacleType, PlanningStatus -from dimos.manipulation.planning.spec.protocols import ( - KinematicsSpec, - PlannerSpec, - WorldSpec, -) -from dimos.manipulation.planning.spec.types import ( - CollisionObjectMessage, - IKResult, - Jacobian, - JointPath, - Obstacle, - PlanningResult, - RobotName, - WorldRobotID, -) - -__all__ = [ - "CollisionObjectMessage", - "IKResult", - "IKStatus", - "Jacobian", - "JointPath", - "KinematicsSpec", - "Obstacle", - "ObstacleType", - "PlannerSpec", - "PlanningResult", - "PlanningStatus", - "RobotModelConfig", - "RobotName", - "WorldRobotID", - "WorldSpec", -] diff --git a/dimos/manipulation/planning/spec/config.py b/dimos/manipulation/planning/spec/config.py index e379fc1eb5..80cf248f08 100644 --- a/dimos/manipulation/planning/spec/config.py +++ b/dimos/manipulation/planning/spec/config.py @@ -22,7 +22,7 @@ from pydantic import Field from dimos.core.module import ModuleConfig -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped class RobotModelConfig(ModuleConfig): diff --git a/dimos/manipulation/planning/spec/types.py b/dimos/manipulation/planning/spec/models.py similarity index 97% rename from dimos/manipulation/planning/spec/types.py rename to dimos/manipulation/planning/spec/models.py index 2683db7814..37daa331e4 100644 --- a/dimos/manipulation/planning/spec/types.py +++ b/dimos/manipulation/planning/spec/models.py @@ -29,8 +29,8 @@ import numpy as np from numpy.typing import NDArray - from dimos.msgs.geometry_msgs import PoseStamped - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from dimos.msgs.sensor_msgs.JointState import JointState RobotName: TypeAlias = str diff --git a/dimos/manipulation/planning/spec/protocols.py b/dimos/manipulation/planning/spec/protocols.py index dea4718abb..76ecd1780b 100644 --- a/dimos/manipulation/planning/spec/protocols.py +++ b/dimos/manipulation/planning/spec/protocols.py @@ -29,15 +29,15 @@ from numpy.typing import NDArray from dimos.manipulation.planning.spec.config import RobotModelConfig - from dimos.manipulation.planning.spec.types import ( + from dimos.manipulation.planning.spec.models import ( IKResult, JointPath, Obstacle, PlanningResult, WorldRobotID, ) - from dimos.msgs.geometry_msgs import PoseStamped - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from dimos.msgs.sensor_msgs.JointState import JointState @runtime_checkable diff --git a/dimos/manipulation/planning/trajectory_generator/__init__.py b/dimos/manipulation/planning/trajectory_generator/__init__.py deleted file mode 100644 index a7449cf45f..0000000000 --- a/dimos/manipulation/planning/trajectory_generator/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Trajectory Generator Module - -Generates time-parameterized trajectories from waypoints. -""" - -from dimos.manipulation.planning.trajectory_generator.joint_trajectory_generator import ( - JointTrajectoryGenerator, -) - -__all__ = ["JointTrajectoryGenerator"] diff --git a/dimos/manipulation/planning/trajectory_generator/joint_trajectory_generator.py b/dimos/manipulation/planning/trajectory_generator/joint_trajectory_generator.py index 6b732d133c..1ac6b74351 100644 --- a/dimos/manipulation/planning/trajectory_generator/joint_trajectory_generator.py +++ b/dimos/manipulation/planning/trajectory_generator/joint_trajectory_generator.py @@ -32,7 +32,8 @@ import math -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryPoint import TrajectoryPoint class JointTrajectoryGenerator: diff --git a/dimos/manipulation/planning/trajectory_generator/spec.py b/dimos/manipulation/planning/trajectory_generator/spec.py index 5357679f28..0814f5dc0b 100644 --- a/dimos/manipulation/planning/trajectory_generator/spec.py +++ b/dimos/manipulation/planning/trajectory_generator/spec.py @@ -35,7 +35,7 @@ from typing import Protocol -from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory class JointTrajectoryGeneratorSpec(Protocol): diff --git a/dimos/manipulation/planning/utils/__init__.py b/dimos/manipulation/planning/utils/__init__.py deleted file mode 100644 index 04ec1806b5..0000000000 --- a/dimos/manipulation/planning/utils/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Manipulation Planning Utilities - -Standalone utility functions for kinematics and path operations. -These are extracted from the old ABC base classes to enable composition over inheritance. - -## Modules - -- kinematics_utils: Jacobian operations, singularity detection, pose error computation -- path_utils: Path interpolation, simplification, length computation -""" - -from dimos.manipulation.planning.utils.kinematics_utils import ( - check_singularity, - compute_error_twist, - compute_pose_error, - damped_pseudoinverse, - get_manipulability, -) -from dimos.manipulation.planning.utils.path_utils import ( - compute_path_length, - interpolate_path, - interpolate_segment, -) - -__all__ = [ - # Kinematics utilities - "check_singularity", - "compute_error_twist", - # Path utilities - "compute_path_length", - "compute_pose_error", - "damped_pseudoinverse", - "get_manipulability", - "interpolate_path", - "interpolate_segment", -] diff --git a/dimos/manipulation/planning/utils/kinematics_utils.py b/dimos/manipulation/planning/utils/kinematics_utils.py index c9f3f95a3d..02e885f1ae 100644 --- a/dimos/manipulation/planning/utils/kinematics_utils.py +++ b/dimos/manipulation/planning/utils/kinematics_utils.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from numpy.typing import NDArray - from dimos.manipulation.planning.spec import Jacobian + from dimos.manipulation.planning.spec.models import Jacobian def damped_pseudoinverse( diff --git a/dimos/manipulation/planning/utils/path_utils.py b/dimos/manipulation/planning/utils/path_utils.py index fbf8af4032..dd5de1a0a4 100644 --- a/dimos/manipulation/planning/utils/path_utils.py +++ b/dimos/manipulation/planning/utils/path_utils.py @@ -32,12 +32,13 @@ import numpy as np -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.sensor_msgs.JointState import JointState if TYPE_CHECKING: from numpy.typing import NDArray - from dimos.manipulation.planning.spec import JointPath, WorldRobotID, WorldSpec + from dimos.manipulation.planning.spec.models import JointPath, WorldRobotID + from dimos.manipulation.planning.spec.protocols import WorldSpec def interpolate_path( diff --git a/dimos/manipulation/planning/world/__init__.py b/dimos/manipulation/planning/world/__init__.py deleted file mode 100644 index 8ddef7fdff..0000000000 --- a/dimos/manipulation/planning/world/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -World Module - -Contains world implementations that own the physics/collision backend. - -## Implementations - -- DrakeWorld: Uses Drake MultibodyPlant + SceneGraph -""" - -from dimos.manipulation.planning.world.drake_world import DrakeWorld - -__all__ = ["DrakeWorld"] diff --git a/dimos/manipulation/planning/world/drake_world.py b/dimos/manipulation/planning/world/drake_world.py index 147e1e3ad3..ce155253ca 100644 --- a/dimos/manipulation/planning/world/drake_world.py +++ b/dimos/manipulation/planning/world/drake_world.py @@ -25,14 +25,10 @@ import numpy as np -from dimos.manipulation.planning.spec import ( - JointPath, - Obstacle, - ObstacleType, - RobotModelConfig, - WorldRobotID, - WorldSpec, -) +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.manipulation.planning.spec.enums import ObstacleType +from dimos.manipulation.planning.spec.models import JointPath, Obstacle, WorldRobotID +from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.mesh_utils import prepare_urdf_for_drake from dimos.utils.logging_config import setup_logger @@ -41,8 +37,9 @@ from numpy.typing import NDArray -from dimos.msgs.geometry_msgs import PoseStamped, Transform -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.JointState import JointState try: from pydrake.geometry import ( # type: ignore[import-not-found] diff --git a/dimos/manipulation/test_manipulation_module.py b/dimos/manipulation/test_manipulation_module.py index c30ba9b55c..46a196e28c 100644 --- a/dimos/manipulation/test_manipulation_module.py +++ b/dimos/manipulation/test_manipulation_module.py @@ -30,9 +30,12 @@ ManipulationModule, ManipulationState, ) -from dimos.manipulation.planning.spec import RobotModelConfig -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Vector3 -from dimos.msgs.sensor_msgs import JointState +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.data import get_data diff --git a/dimos/manipulation/test_manipulation_unit.py b/dimos/manipulation/test_manipulation_unit.py index cfd6e35fda..67ca8332b4 100644 --- a/dimos/manipulation/test_manipulation_unit.py +++ b/dimos/manipulation/test_manipulation_unit.py @@ -26,9 +26,12 @@ ManipulationModule, ManipulationState, ) -from dimos.manipulation.planning.spec import RobotModelConfig -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryPoint import TrajectoryPoint @pytest.fixture diff --git a/dimos/mapping/__init__.py b/dimos/mapping/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/mapping/costmapper.py b/dimos/mapping/costmapper.py index 75b674b2a0..06bf493564 100644 --- a/dimos/mapping/costmapper.py +++ b/dimos/mapping/costmapper.py @@ -26,8 +26,8 @@ HeightCostConfig, OccupancyConfig, ) -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/mapping/google_maps/google_maps.py b/dimos/mapping/google_maps/google_maps.py index 7f5ce32e99..18a1e25e2b 100644 --- a/dimos/mapping/google_maps/google_maps.py +++ b/dimos/mapping/google_maps/google_maps.py @@ -16,14 +16,14 @@ import googlemaps # type: ignore[import-untyped] -from dimos.mapping.google_maps.types import ( +from dimos.mapping.google_maps.models import ( Coordinates, LocationContext, NearbyPlace, PlacePosition, Position, ) -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon from dimos.mapping.utils.distance import distance_in_meters from dimos.utils.logging_config import setup_logger diff --git a/dimos/mapping/google_maps/types.py b/dimos/mapping/google_maps/models.py similarity index 100% rename from dimos/mapping/google_maps/types.py rename to dimos/mapping/google_maps/models.py diff --git a/dimos/mapping/google_maps/test_google_maps.py b/dimos/mapping/google_maps/test_google_maps.py index 13f7fa8eaa..2805f5589c 100644 --- a/dimos/mapping/google_maps/test_google_maps.py +++ b/dimos/mapping/google_maps/test_google_maps.py @@ -13,7 +13,7 @@ # limitations under the License. -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon def test_get_position(maps_client, maps_fixture) -> None: diff --git a/dimos/mapping/types.py b/dimos/mapping/models.py similarity index 100% rename from dimos/mapping/types.py rename to dimos/mapping/models.py diff --git a/dimos/mapping/occupancy/path_mask.py b/dimos/mapping/occupancy/path_mask.py index 5ad3010111..7744ab95ba 100644 --- a/dimos/mapping/occupancy/path_mask.py +++ b/dimos/mapping/occupancy/path_mask.py @@ -16,8 +16,8 @@ import numpy as np from numpy.typing import NDArray -from dimos.msgs.nav_msgs import Path from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path def make_path_mask( diff --git a/dimos/mapping/occupancy/path_resampling.py b/dimos/mapping/occupancy/path_resampling.py index 2090bf8f04..4d957a1aad 100644 --- a/dimos/mapping/occupancy/path_resampling.py +++ b/dimos/mapping/occupancy/path_resampling.py @@ -18,8 +18,11 @@ import numpy as np from scipy.ndimage import uniform_filter1d # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Vector3 -from dimos.msgs.nav_msgs import Path +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Path import Path from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion diff --git a/dimos/mapping/occupancy/test_path_mask.py b/dimos/mapping/occupancy/test_path_mask.py index dede997946..f566af2a23 100644 --- a/dimos/mapping/occupancy/test_path_mask.py +++ b/dimos/mapping/occupancy/test_path_mask.py @@ -19,9 +19,9 @@ from dimos.mapping.occupancy.path_mask import make_path_mask from dimos.mapping.occupancy.path_resampling import smooth_resample_path from dimos.mapping.occupancy.visualizations import visualize_occupancy_grid -from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.navigation.replanning_a_star.min_cost_astar import min_cost_astar from dimos.utils.data import get_data diff --git a/dimos/mapping/occupancy/test_path_resampling.py b/dimos/mapping/occupancy/test_path_resampling.py index c23f71cf89..aeda7d11ad 100644 --- a/dimos/mapping/occupancy/test_path_resampling.py +++ b/dimos/mapping/occupancy/test_path_resampling.py @@ -18,7 +18,7 @@ from dimos.mapping.occupancy.gradient import gradient from dimos.mapping.occupancy.path_resampling import simple_resample_path, smooth_resample_path from dimos.mapping.occupancy.visualize_path import visualize_path -from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid from dimos.msgs.sensor_msgs.Image import Image diff --git a/dimos/mapping/occupancy/visualizations.py b/dimos/mapping/occupancy/visualizations.py index 2ed0364257..36321896be 100644 --- a/dimos/mapping/occupancy/visualizations.py +++ b/dimos/mapping/occupancy/visualizations.py @@ -19,8 +19,8 @@ import numpy as np from numpy.typing import NDArray -from dimos.msgs.nav_msgs import Path from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.msgs.sensor_msgs.Image import Image, ImageFormat Palette: TypeAlias = Literal["rainbow", "turbo"] diff --git a/dimos/mapping/occupancy/visualize_path.py b/dimos/mapping/occupancy/visualize_path.py index 0662582f72..89dcf83067 100644 --- a/dimos/mapping/occupancy/visualize_path.py +++ b/dimos/mapping/occupancy/visualize_path.py @@ -16,8 +16,8 @@ import numpy as np from dimos.mapping.occupancy.visualizations import visualize_occupancy_grid -from dimos.msgs.nav_msgs import Path from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.msgs.sensor_msgs.Image import Image, ImageFormat diff --git a/dimos/mapping/osm/__init__.py b/dimos/mapping/osm/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/mapping/osm/current_location_map.py b/dimos/mapping/osm/current_location_map.py index 832116e25c..4cfeddc9b8 100644 --- a/dimos/mapping/osm/current_location_map.py +++ b/dimos/mapping/osm/current_location_map.py @@ -16,9 +16,9 @@ from PIL import Image as PILImage, ImageDraw +from dimos.mapping.models import LatLon from dimos.mapping.osm.osm import MapImage, get_osm_map from dimos.mapping.osm.query import query_for_one_position, query_for_one_position_and_context -from dimos.mapping.types import LatLon from dimos.models.vl.base import VlModel from dimos.utils.logging_config import setup_logger diff --git a/dimos/mapping/osm/osm.py b/dimos/mapping/osm/osm.py index 31fb044087..f9b7eaafda 100644 --- a/dimos/mapping/osm/osm.py +++ b/dimos/mapping/osm/osm.py @@ -21,8 +21,8 @@ from PIL import Image as PILImage import requests # type: ignore[import-untyped] -from dimos.mapping.types import ImageCoord, LatLon -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.mapping.models import ImageCoord, LatLon +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat @dataclass(frozen=True) diff --git a/dimos/mapping/osm/query.py b/dimos/mapping/osm/query.py index 17fbfe3d4b..7a3c3b0154 100644 --- a/dimos/mapping/osm/query.py +++ b/dimos/mapping/osm/query.py @@ -15,8 +15,8 @@ import re from typing import Any +from dimos.mapping.models import LatLon from dimos.mapping.osm.osm import MapImage -from dimos.mapping.types import LatLon from dimos.models.vl.base import VlModel from dimos.utils.generic import extract_json_from_llm_response from dimos.utils.logging_config import setup_logger diff --git a/dimos/mapping/osm/test_osm.py b/dimos/mapping/osm/test_osm.py index 475e2b40fc..64fbb72b02 100644 --- a/dimos/mapping/osm/test_osm.py +++ b/dimos/mapping/osm/test_osm.py @@ -21,8 +21,8 @@ from requests import Request import requests_mock +from dimos.mapping.models import LatLon from dimos.mapping.osm.osm import get_osm_map -from dimos.mapping.types import LatLon from dimos.utils.data import get_data _fixture_dir = get_data("osm_map_test") diff --git a/dimos/mapping/pointclouds/demo.py b/dimos/mapping/pointclouds/demo.py index 5251fc3406..2812aaae42 100644 --- a/dimos/mapping/pointclouds/demo.py +++ b/dimos/mapping/pointclouds/demo.py @@ -25,8 +25,8 @@ read_pointcloud, visualize, ) -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data app = typer.Typer() diff --git a/dimos/mapping/pointclouds/occupancy.py b/dimos/mapping/pointclouds/occupancy.py index 0f6ad8c0de..c9cd7e7af3 100644 --- a/dimos/mapping/pointclouds/occupancy.py +++ b/dimos/mapping/pointclouds/occupancy.py @@ -21,7 +21,7 @@ import numpy as np from scipy import ndimage # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid if TYPE_CHECKING: @@ -99,7 +99,7 @@ def _simple_occupancy_kernel( if TYPE_CHECKING: from collections.abc import Callable - from dimos.msgs.sensor_msgs import PointCloud2 + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 @dataclass(frozen=True) diff --git a/dimos/mapping/pointclouds/test_occupancy.py b/dimos/mapping/pointclouds/test_occupancy.py index d265800f24..93b5793dc8 100644 --- a/dimos/mapping/pointclouds/test_occupancy.py +++ b/dimos/mapping/pointclouds/test_occupancy.py @@ -26,8 +26,8 @@ ) from dimos.mapping.pointclouds.util import read_pointcloud from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data from dimos.utils.testing.moment import OutputMoment from dimos.utils.testing.test_moment import Go2Moment diff --git a/dimos/mapping/pointclouds/test_occupancy_speed.py b/dimos/mapping/pointclouds/test_occupancy_speed.py index 2def839dd5..ac4085e971 100644 --- a/dimos/mapping/pointclouds/test_occupancy_speed.py +++ b/dimos/mapping/pointclouds/test_occupancy_speed.py @@ -21,7 +21,7 @@ from dimos.mapping.voxels import VoxelGridMapper from dimos.utils.cli.plot import bar from dimos.utils.data import get_data, get_data_dir -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay @pytest.mark.tool diff --git a/dimos/mapping/test_voxels.py b/dimos/mapping/test_voxels.py index 95e70e1d6d..bb5f4ed764 100644 --- a/dimos/mapping/test_voxels.py +++ b/dimos/mapping/test_voxels.py @@ -20,7 +20,7 @@ from dimos.core.transport import LCMTransport from dimos.mapping.voxels import VoxelGridMapper -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data from dimos.utils.testing.moment import OutputMoment from dimos.utils.testing.replay import TimedSensorReplay diff --git a/dimos/mapping/utils/distance.py b/dimos/mapping/utils/distance.py index 6e8c48c205..42b8a9be04 100644 --- a/dimos/mapping/utils/distance.py +++ b/dimos/mapping/utils/distance.py @@ -14,7 +14,7 @@ import math -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon def distance_in_meters(location1: LatLon, location2: LatLon) -> float: diff --git a/dimos/mapping/voxels.py b/dimos/mapping/voxels.py index c2078dc309..e4e03dfc01 100644 --- a/dimos/mapping/voxels.py +++ b/dimos/mapping/voxels.py @@ -25,8 +25,8 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.utils.decorators import simple_mcache +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.logging_config import setup_logger from dimos.utils.reactive import backpressure diff --git a/dimos/memory/embedding.py b/dimos/memory/embedding.py index e09e069f05..be73d01ac1 100644 --- a/dimos/memory/embedding.py +++ b/dimos/memory/embedding.py @@ -26,9 +26,8 @@ from dimos.core.stream import In from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.models.embedding.clip import CLIPModel -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.utils.reactive import getter_hot diff --git a/dimos/memory/test_embedding.py b/dimos/memory/test_embedding.py index b7e7fbb294..9a59ed51e1 100644 --- a/dimos/memory/test_embedding.py +++ b/dimos/memory/test_embedding.py @@ -15,9 +15,9 @@ import pytest from dimos.memory.embedding import EmbeddingMemory, SpatialEntry -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay dir_name = "unitree_go2_bigoffice" diff --git a/dimos/memory/timeseries/__init__.py b/dimos/memory/timeseries/__init__.py deleted file mode 100644 index debc14ab3a..0000000000 --- a/dimos/memory/timeseries/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Time series storage and replay.""" - -from dimos.memory.timeseries.base import TimeSeriesStore -from dimos.memory.timeseries.inmemory import InMemoryStore -from dimos.memory.timeseries.pickledir import PickleDirStore -from dimos.memory.timeseries.sqlite import SqliteStore - - -def __getattr__(name: str): # type: ignore[no-untyped-def] - if name == "PostgresStore": - from dimos.memory.timeseries.postgres import PostgresStore - - return PostgresStore - if name == "reset_db": - from dimos.memory.timeseries.postgres import reset_db - - return reset_db - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -__all__ = [ - "InMemoryStore", - "PickleDirStore", - "PostgresStore", - "SqliteStore", - "TimeSeriesStore", - "reset_db", -] diff --git a/dimos/models/__init__.py b/dimos/models/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/models/base.py b/dimos/models/base.py index d03ce5c539..fd18d8ba93 100644 --- a/dimos/models/base.py +++ b/dimos/models/base.py @@ -22,7 +22,7 @@ import torch from dimos.core.resource import Resource -from dimos.protocol.service import BaseConfig, Configurable +from dimos.protocol.service.spec import BaseConfig, Configurable # Device string type - 'cuda', 'cpu', 'cuda:0', 'cuda:1', etc. DeviceType = Annotated[str, "Device identifier (e.g., 'cuda', 'cpu', 'cuda:0')"] diff --git a/dimos/models/embedding/__init__.py b/dimos/models/embedding/__init__.py deleted file mode 100644 index 050d35467e..0000000000 --- a/dimos/models/embedding/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -from dimos.models.embedding.base import Embedding, EmbeddingModel - -__all__ = [ - "Embedding", - "EmbeddingModel", -] - -# Optional: CLIP support -try: - from dimos.models.embedding.clip import CLIPModel - - __all__.append("CLIPModel") -except ImportError: - pass - -# Optional: MobileCLIP support -try: - from dimos.models.embedding.mobileclip import MobileCLIPModel - - __all__.append("MobileCLIPModel") -except ImportError: - pass - -# Optional: TorchReID support -try: - from dimos.models.embedding.treid import TorchReIDModel - - __all__.append("TorchReIDModel") -except ImportError: - pass diff --git a/dimos/models/embedding/base.py b/dimos/models/embedding/base.py index 520818aabf..0c80cafc0a 100644 --- a/dimos/models/embedding/base.py +++ b/dimos/models/embedding/base.py @@ -25,7 +25,7 @@ from dimos.types.timestamped import Timestamped if TYPE_CHECKING: - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image class EmbeddingModelConfig(LocalModelConfig): diff --git a/dimos/models/embedding/clip.py b/dimos/models/embedding/clip.py index e3a61e9570..6fb42b7ccf 100644 --- a/dimos/models/embedding/clip.py +++ b/dimos/models/embedding/clip.py @@ -21,7 +21,7 @@ from dimos.models.base import HuggingFaceModel from dimos.models.embedding.base import Embedding, EmbeddingModel, HuggingFaceEmbeddingModelConfig -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image class CLIPModelConfig(HuggingFaceEmbeddingModelConfig): diff --git a/dimos/models/embedding/mobileclip.py b/dimos/models/embedding/mobileclip.py index 8ad37936be..84bba74829 100644 --- a/dimos/models/embedding/mobileclip.py +++ b/dimos/models/embedding/mobileclip.py @@ -22,7 +22,7 @@ from dimos.models.base import LocalModel from dimos.models.embedding.base import Embedding, EmbeddingModel, EmbeddingModelConfig -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data diff --git a/dimos/models/embedding/test_embedding.py b/dimos/models/embedding/test_embedding.py index 466c974b32..20aac83dbb 100644 --- a/dimos/models/embedding/test_embedding.py +++ b/dimos/models/embedding/test_embedding.py @@ -7,7 +7,7 @@ from dimos.models.embedding.clip import CLIPModel from dimos.models.embedding.mobileclip import MobileCLIPModel from dimos.models.embedding.treid import TorchReIDModel -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data diff --git a/dimos/models/embedding/treid.py b/dimos/models/embedding/treid.py index 69cc1aae13..21a4527781 100644 --- a/dimos/models/embedding/treid.py +++ b/dimos/models/embedding/treid.py @@ -24,7 +24,7 @@ from dimos.models.base import LocalModel from dimos.models.embedding.base import Embedding, EmbeddingModel, EmbeddingModelConfig -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data diff --git a/dimos/models/segmentation/edge_tam.py b/dimos/models/segmentation/edge_tam.py index 54158b2b92..e9744f6d81 100644 --- a/dimos/models/segmentation/edge_tam.py +++ b/dimos/models/segmentation/edge_tam.py @@ -28,9 +28,9 @@ from PIL import Image as PILImage import torch -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.detectors.types import Detector -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.detectors.base import Detector +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.perception.detection.type.detection2d.seg import Detection2DSeg from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger diff --git a/dimos/models/vl/__init__.py b/dimos/models/vl/__init__.py deleted file mode 100644 index 482a907cbd..0000000000 --- a/dimos/models/vl/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "base": ["Captioner", "VlModel"], - "florence": ["Florence2Model"], - "moondream": ["MoondreamVlModel"], - "moondream_hosted": ["MoondreamHostedVlModel"], - "openai": ["OpenAIVlModel"], - "qwen": ["QwenVlModel"], - }, -) diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index 1cdeb3f92f..08b83fc503 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -8,11 +8,13 @@ import warnings from dimos.core.resource import Resource -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection2d.point import Detection2DPoint from dimos.protocol.service.spec import BaseConfig, Configurable from dimos.utils.data import get_data -from dimos.utils.decorators import retry +from dimos.utils.decorators.decorators import retry from dimos.utils.llm_utils import extract_json if sys.version_info < (3, 13): @@ -73,7 +75,7 @@ def vlm_detection_to_detection2d( Detection2DBBox instance or None if invalid """ # Here to prevent unwanted imports in the file. - from dimos.perception.detection.type import Detection2DBBox + from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox # Validate list/tuple structure if not isinstance(vlm_detection, (list, tuple)): @@ -130,7 +132,7 @@ def vlm_point_to_detection2d_point( Returns: Detection2DPoint instance or None if invalid """ - from dimos.perception.detection.type import Detection2DPoint + from dimos.perception.detection.type.detection2d.point import Detection2DPoint # Validate list/tuple structure if not isinstance(vlm_point, (list, tuple)): @@ -260,7 +262,7 @@ def query_detections( self, image: Image, query: str, **kwargs: Any ) -> ImageDetections2D[Detection2DBBox]: # Here to prevent unwanted imports in the file. - from dimos.perception.detection.type import ImageDetections2D + from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D full_query = f"""show me bounding boxes in pixels for this query: `{query}` @@ -321,7 +323,7 @@ def query_points( ImageDetections2D containing Detection2DPoint instances """ # Here to prevent unwanted imports in the file. - from dimos.perception.detection.type import ImageDetections2D + from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D full_query = f"""Show me point coordinates in pixels for this query: `{query}` diff --git a/dimos/models/vl/florence.py b/dimos/models/vl/florence.py index 2e6cf822a8..b68441328a 100644 --- a/dimos/models/vl/florence.py +++ b/dimos/models/vl/florence.py @@ -20,7 +20,7 @@ from dimos.models.base import HuggingFaceModel from dimos.models.vl.base import Captioner -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image class Florence2Model(HuggingFaceModel, Captioner): diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index c444d8b9ed..0f5e501ef6 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -9,8 +9,10 @@ from dimos.models.base import HuggingFaceModel, HuggingFaceModelConfig from dimos.models.vl.base import VlModel, VlModelConfig -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection2d.point import Detection2DPoint # Moondream works well with 512x512 max MOONDREAM_DEFAULT_AUTO_RESIZE = (512, 512) diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py index 57df91b47e..aad9fe514c 100644 --- a/dimos/models/vl/moondream_hosted.py +++ b/dimos/models/vl/moondream_hosted.py @@ -7,8 +7,10 @@ from PIL import Image as PILImage from dimos.models.vl.base import VlModel, VlModelConfig -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection2d.point import Detection2DPoint class Config(VlModelConfig): diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index ec774189e4..0486bbdb30 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -6,7 +6,7 @@ from openai import OpenAI from dimos.models.vl.base import VlModel, VlModelConfig -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index 014c6f73a5..202ce6759e 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -6,7 +6,7 @@ from openai import OpenAI from dimos.models.vl.base import VlModel, VlModelConfig -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image class QwenVlModelConfig(VlModelConfig): diff --git a/dimos/models/vl/test_base.py b/dimos/models/vl/test_base.py index 0cc5c90d0e..b0b03e70fa 100644 --- a/dimos/models/vl/test_base.py +++ b/dimos/models/vl/test_base.py @@ -6,8 +6,8 @@ from dimos.core.transport import LCMTransport from dimos.models.vl.moondream import MoondreamVlModel from dimos.models.vl.qwen import QwenVlModel -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.utils.data import get_data # Captured actual response from Qwen API for cafe.jpg with query "humans" diff --git a/dimos/models/vl/test_captioner.py b/dimos/models/vl/test_captioner.py index c7ebb8fc63..734c83290e 100644 --- a/dimos/models/vl/test_captioner.py +++ b/dimos/models/vl/test_captioner.py @@ -6,7 +6,7 @@ from dimos.models.vl.florence import Florence2Model from dimos.models.vl.moondream import MoondreamVlModel -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data diff --git a/dimos/models/vl/test_vlm.py b/dimos/models/vl/test_vlm.py index 43dad0ef94..f0fd3b8d5a 100644 --- a/dimos/models/vl/test_vlm.py +++ b/dimos/models/vl/test_vlm.py @@ -11,8 +11,8 @@ from dimos.models.vl.moondream import MoondreamVlModel from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel from dimos.models.vl.qwen import QwenVlModel -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.utils.cli.plot import bar from dimos.utils.data import get_data @@ -228,7 +228,7 @@ def test_vlm_query_multi(model_class: "type[VlModel]", model_name: str) -> None: @pytest.mark.slow def test_vlm_query_batch(model_class: "type[VlModel]", model_name: str) -> None: """Test query_batch optimization - multiple images, same query.""" - from dimos.utils.testing import TimedSensorReplay + from dimos.utils.testing.replay import TimedSensorReplay # Load 5 frames at 1-second intervals using TimedSensorReplay replay = TimedSensorReplay[Image]("unitree_go2_office_walk2/video") @@ -285,7 +285,7 @@ def test_vlm_resize( sizes: list[tuple[int, int] | None], ) -> None: """Test VLM auto_resize effect on performance.""" - from dimos.utils.testing import TimedSensorReplay + from dimos.utils.testing.replay import TimedSensorReplay replay = TimedSensorReplay[Image]("unitree_go2_office_walk2/video") image = replay.find_closest_seek(0).to_rgb() diff --git a/dimos/msgs/__init__.py b/dimos/msgs/__init__.py deleted file mode 100644 index 4395dbcc51..0000000000 --- a/dimos/msgs/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from dimos.msgs.helpers import resolve_msg_type -from dimos.msgs.protocol import DimosMsg - -__all__ = ["DimosMsg", "resolve_msg_type"] diff --git a/dimos/msgs/foxglove_msgs/__init__.py b/dimos/msgs/foxglove_msgs/__init__.py deleted file mode 100644 index 945ebf94c9..0000000000 --- a/dimos/msgs/foxglove_msgs/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations - -__all__ = ["ImageAnnotations"] diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index 5f50f9b9d1..9b08c8dadd 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -29,7 +29,7 @@ from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.msgs.std_msgs import Header +from dimos.msgs.std_msgs.Header import Header from dimos.types.timestamped import Timestamped diff --git a/dimos/msgs/geometry_msgs/__init__.py b/dimos/msgs/geometry_msgs/__init__.py deleted file mode 100644 index 01069d765c..0000000000 --- a/dimos/msgs/geometry_msgs/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -from dimos.msgs.geometry_msgs.Point import Point -from dimos.msgs.geometry_msgs.PointStamped import PointStamped -from dimos.msgs.geometry_msgs.Pose import Pose, PoseLike, to_pose -from dimos.msgs.geometry_msgs.PoseArray import PoseArray -from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance -from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import PoseWithCovarianceStamped -from dimos.msgs.geometry_msgs.Quaternion import Quaternion -from dimos.msgs.geometry_msgs.Transform import Transform -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped -from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance -from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import TwistWithCovarianceStamped -from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike -from dimos.msgs.geometry_msgs.Wrench import Wrench -from dimos.msgs.geometry_msgs.WrenchStamped import WrenchStamped - -__all__ = [ - "Point", - "PointStamped", - "Pose", - "PoseArray", - "PoseLike", - "PoseStamped", - "PoseWithCovariance", - "PoseWithCovarianceStamped", - "Quaternion", - "Transform", - "Twist", - "TwistStamped", - "TwistWithCovariance", - "TwistWithCovarianceStamped", - "Vector3", - "VectorLike", - "Wrench", - "WrenchStamped", - "to_pose", -] diff --git a/dimos/msgs/geometry_msgs/test_PoseStamped.py b/dimos/msgs/geometry_msgs/test_PoseStamped.py index 82250a9113..a486f33303 100644 --- a/dimos/msgs/geometry_msgs/test_PoseStamped.py +++ b/dimos/msgs/geometry_msgs/test_PoseStamped.py @@ -15,7 +15,7 @@ import pickle import time -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped def test_lcm_encode_decode() -> None: diff --git a/dimos/msgs/geometry_msgs/test_Transform.py b/dimos/msgs/geometry_msgs/test_Transform.py index 0c15610b05..056238719a 100644 --- a/dimos/msgs/geometry_msgs/test_Transform.py +++ b/dimos/msgs/geometry_msgs/test_Transform.py @@ -18,7 +18,11 @@ import numpy as np import pytest -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 def test_transform_initialization() -> None: diff --git a/dimos/msgs/geometry_msgs/test_Twist.py b/dimos/msgs/geometry_msgs/test_Twist.py index df4bd8b6a2..a4dc93f3cc 100644 --- a/dimos/msgs/geometry_msgs/test_Twist.py +++ b/dimos/msgs/geometry_msgs/test_Twist.py @@ -15,7 +15,9 @@ from dimos_lcm.geometry_msgs import Twist as LCMTwist import numpy as np -from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 def test_twist_initialization() -> None: diff --git a/dimos/msgs/geometry_msgs/test_publish.py b/dimos/msgs/geometry_msgs/test_publish.py index b3d2324af0..01c5cf7842 100644 --- a/dimos/msgs/geometry_msgs/test_publish.py +++ b/dimos/msgs/geometry_msgs/test_publish.py @@ -17,7 +17,7 @@ import lcm import pytest -from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 @pytest.mark.tool diff --git a/dimos/msgs/helpers.py b/dimos/msgs/helpers.py index 8464ec4ab1..91466f7fdd 100644 --- a/dimos/msgs/helpers.py +++ b/dimos/msgs/helpers.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from dimos.msgs import DimosMsg + from dimos.msgs.protocol import DimosMsg @lru_cache(maxsize=256) @@ -38,7 +38,10 @@ def resolve_msg_type(type_name: str) -> type[DimosMsg] | None: return None # Try different import paths + # First try the direct submodule path (e.g., dimos.msgs.geometry_msgs.Quaternion) + # then fall back to parent package (for dimos_lcm or other packages) import_paths = [ + f"dimos.msgs.{module_name}.{class_name}", f"dimos.msgs.{module_name}", f"dimos_lcm.{module_name}", ] diff --git a/dimos/msgs/nav_msgs/OccupancyGrid.py b/dimos/msgs/nav_msgs/OccupancyGrid.py index d45e1b6232..4760884620 100644 --- a/dimos/msgs/nav_msgs/OccupancyGrid.py +++ b/dimos/msgs/nav_msgs/OccupancyGrid.py @@ -28,7 +28,8 @@ import numpy as np from PIL import Image -from dimos.msgs.geometry_msgs import Pose, Vector3, VectorLike +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike from dimos.types.timestamped import Timestamped diff --git a/dimos/msgs/nav_msgs/__init__.py b/dimos/msgs/nav_msgs/__init__.py deleted file mode 100644 index 9d099068ad..0000000000 --- a/dimos/msgs/nav_msgs/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from dimos.msgs.nav_msgs.OccupancyGrid import ( # type: ignore[attr-defined] - CostValues, - MapMetaData, - OccupancyGrid, -) -from dimos.msgs.nav_msgs.Odometry import Odometry -from dimos.msgs.nav_msgs.Path import Path - -__all__ = ["CostValues", "MapMetaData", "OccupancyGrid", "Odometry", "Path"] diff --git a/dimos/msgs/nav_msgs/test_OccupancyGrid.py b/dimos/msgs/nav_msgs/test_OccupancyGrid.py index d1ec8938b4..7aae8abfac 100644 --- a/dimos/msgs/nav_msgs/test_OccupancyGrid.py +++ b/dimos/msgs/nav_msgs/test_OccupancyGrid.py @@ -23,9 +23,9 @@ from dimos.mapping.occupancy.gradient import gradient from dimos.mapping.occupancy.inflation import simple_inflate from dimos.mapping.pointclouds.occupancy import general_occupancy -from dimos.msgs.geometry_msgs import Pose -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data diff --git a/dimos/msgs/sensor_msgs/Imu.py b/dimos/msgs/sensor_msgs/Imu.py index 7fe03ce03f..f3461975ff 100644 --- a/dimos/msgs/sensor_msgs/Imu.py +++ b/dimos/msgs/sensor_msgs/Imu.py @@ -18,7 +18,8 @@ from dimos_lcm.sensor_msgs.Imu import Imu as LCMImu -from dimos.msgs.geometry_msgs import Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.types.timestamped import Timestamped diff --git a/dimos/msgs/sensor_msgs/PointCloud2.py b/dimos/msgs/sensor_msgs/PointCloud2.py index 22fe731a70..67af1c5ac3 100644 --- a/dimos/msgs/sensor_msgs/PointCloud2.py +++ b/dimos/msgs/sensor_msgs/PointCloud2.py @@ -28,7 +28,8 @@ import open3d as o3d # type: ignore[import-untyped] import open3d.core as o3c # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import Transform, Vector3 +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.types.timestamped import Timestamped if TYPE_CHECKING: diff --git a/dimos/msgs/sensor_msgs/__init__.py b/dimos/msgs/sensor_msgs/__init__.py deleted file mode 100644 index 7fec2d2793..0000000000 --- a/dimos/msgs/sensor_msgs/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo -from dimos.msgs.sensor_msgs.Image import Image, ImageFormat -from dimos.msgs.sensor_msgs.Imu import Imu -from dimos.msgs.sensor_msgs.JointCommand import JointCommand -from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.msgs.sensor_msgs.Joy import Joy -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.msgs.sensor_msgs.RobotState import RobotState - -__all__ = [ - "CameraInfo", - "Image", - "ImageFormat", - "Imu", - "JointCommand", - "JointState", - "Joy", - "PointCloud2", - "RobotState", -] diff --git a/dimos/msgs/sensor_msgs/test_PointCloud2.py b/dimos/msgs/sensor_msgs/test_PointCloud2.py index f48802ab7a..70e6e35aec 100644 --- a/dimos/msgs/sensor_msgs/test_PointCloud2.py +++ b/dimos/msgs/sensor_msgs/test_PointCloud2.py @@ -16,9 +16,9 @@ import numpy as np -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar -from dimos.utils.testing import SensorReplay +from dimos.utils.testing.replay import SensorReplay def test_lcm_encode_decode() -> None: diff --git a/dimos/msgs/sensor_msgs/test_image.py b/dimos/msgs/sensor_msgs/test_image.py index 24375139b3..cc2fc9f096 100644 --- a/dimos/msgs/sensor_msgs/test_image.py +++ b/dimos/msgs/sensor_msgs/test_image.py @@ -18,7 +18,7 @@ from dimos.msgs.sensor_msgs.Image import Image, ImageFormat, sharpness_barrier from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay @pytest.fixture diff --git a/dimos/msgs/std_msgs/__init__.py b/dimos/msgs/std_msgs/__init__.py deleted file mode 100644 index ae8e3dd8f6..0000000000 --- a/dimos/msgs/std_msgs/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .Bool import Bool -from .Header import Header -from .Int8 import Int8 -from .Int32 import Int32 -from .UInt32 import UInt32 - -__all__ = ["Bool", "Header", "Int8", "Int32", "UInt32"] diff --git a/dimos/msgs/std_msgs/test_header.py b/dimos/msgs/std_msgs/test_header.py index 93f20da283..29f4ee2c0e 100644 --- a/dimos/msgs/std_msgs/test_header.py +++ b/dimos/msgs/std_msgs/test_header.py @@ -15,7 +15,7 @@ from datetime import datetime import time -from dimos.msgs.std_msgs import Header +from dimos.msgs.std_msgs.Header import Header def test_header_initialization_methods() -> None: diff --git a/dimos/msgs/tf2_msgs/__init__.py b/dimos/msgs/tf2_msgs/__init__.py deleted file mode 100644 index 69d4e0137e..0000000000 --- a/dimos/msgs/tf2_msgs/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.msgs.tf2_msgs.TFMessage import TFMessage - -__all__ = ["TFMessage"] diff --git a/dimos/msgs/tf2_msgs/test_TFMessage.py b/dimos/msgs/tf2_msgs/test_TFMessage.py index 8567de9988..c379481f1d 100644 --- a/dimos/msgs/tf2_msgs/test_TFMessage.py +++ b/dimos/msgs/tf2_msgs/test_TFMessage.py @@ -14,8 +14,10 @@ from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -from dimos.msgs.tf2_msgs import TFMessage +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.tf2_msgs.TFMessage import TFMessage def test_tfmessage_initialization() -> None: diff --git a/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py b/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py index 8b58a61a44..2a03b7ee71 100644 --- a/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py +++ b/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py @@ -16,8 +16,10 @@ import pytest -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -from dimos.msgs.tf2_msgs import TFMessage +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic diff --git a/dimos/msgs/trajectory_msgs/__init__.py b/dimos/msgs/trajectory_msgs/__init__.py deleted file mode 100644 index 44039e594e..0000000000 --- a/dimos/msgs/trajectory_msgs/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Trajectory message types. - -Similar to ROS trajectory_msgs package. -""" - -from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory -from dimos.msgs.trajectory_msgs.TrajectoryPoint import TrajectoryPoint -from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState, TrajectoryStatus - -__all__ = [ - "JointTrajectory", - "TrajectoryPoint", - "TrajectoryState", - "TrajectoryStatus", -] diff --git a/dimos/msgs/vision_msgs/__init__.py b/dimos/msgs/vision_msgs/__init__.py deleted file mode 100644 index 0f1c9c8dc1..0000000000 --- a/dimos/msgs/vision_msgs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .BoundingBox2DArray import BoundingBox2DArray -from .BoundingBox3DArray import BoundingBox3DArray -from .Detection2D import Detection2D -from .Detection2DArray import Detection2DArray -from .Detection3D import Detection3D -from .Detection3DArray import Detection3DArray - -__all__ = [ - "BoundingBox2DArray", - "BoundingBox3DArray", - "Detection2D", - "Detection2DArray", - "Detection3D", - "Detection3DArray", -] diff --git a/dimos/navigation/base.py b/dimos/navigation/base.py index 347c4ad124..1530308711 100644 --- a/dimos/navigation/base.py +++ b/dimos/navigation/base.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from enum import Enum -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped class NavigationState(Enum): diff --git a/dimos/navigation/bbox_navigation.py b/dimos/navigation/bbox_navigation.py index 170bff9bcd..c96ba9efad 100644 --- a/dimos/navigation/bbox_navigation.py +++ b/dimos/navigation/bbox_navigation.py @@ -20,8 +20,10 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.utils.logging_config import setup_logger logger = setup_logger(level=logging.DEBUG) diff --git a/dimos/navigation/demo_ros_navigation.py b/dimos/navigation/demo_ros_navigation.py index 4d57867d59..0efa04cd44 100644 --- a/dimos/navigation/demo_ros_navigation.py +++ b/dimos/navigation/demo_ros_navigation.py @@ -15,7 +15,9 @@ import time from dimos.core.module_coordinator import ModuleCoordinator -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.navigation import rosnav from dimos.protocol.service.lcmservice import autoconf from dimos.utils.logging_config import setup_logger diff --git a/dimos/navigation/frontier_exploration/__init__.py b/dimos/navigation/frontier_exploration/__init__.py deleted file mode 100644 index 24ce957ccf..0000000000 --- a/dimos/navigation/frontier_exploration/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .wavefront_frontier_goal_selector import WavefrontFrontierExplorer, wavefront_frontier_explorer - -__all__ = ["WavefrontFrontierExplorer", "wavefront_frontier_explorer"] diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py index 419986780a..834897d396 100644 --- a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -17,8 +17,8 @@ import numpy as np import pytest -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.nav_msgs import CostValues, OccupancyGrid +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) @@ -56,7 +56,7 @@ def quick_costmap(): # One obstacle grid[9:10, 9:10] = CostValues.OCCUPIED - from dimos.msgs.geometry_msgs import Pose + from dimos.msgs.geometry_msgs.Pose import Pose origin = Pose() origin.position.x = -1.0 @@ -97,7 +97,7 @@ def create_test_costmap(width: int = 40, height: int = 40, resolution: float = 0 grid[13:14, 18:22] = CostValues.OCCUPIED # Top corridor obstacle # Create origin at bottom-left, adjusted for map size - from dimos.msgs.geometry_msgs import Pose + from dimos.msgs.geometry_msgs.Pose import Pose origin = Pose() # Center the map around (0, 0) in world coordinates diff --git a/dimos/navigation/frontier_exploration/utils.py b/dimos/navigation/frontier_exploration/utils.py index 28644cdd41..d5ed7df61c 100644 --- a/dimos/navigation/frontier_exploration/utils.py +++ b/dimos/navigation/frontier_exploration/utils.py @@ -19,8 +19,8 @@ import numpy as np from PIL import Image, ImageDraw -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.nav_msgs import CostValues, OccupancyGrid +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid def costmap_to_pil_image(costmap: OccupancyGrid, scale_factor: int = 2) -> Image.Image: diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index f8a5436fc1..20fab41b35 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -34,8 +34,9 @@ from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.mapping.occupancy.inflation import simple_inflate -from dimos.msgs.geometry_msgs import PoseStamped, Vector3 -from dimos.msgs.nav_msgs import CostValues, OccupancyGrid +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import get_distance diff --git a/dimos/navigation/replanning_a_star/controllers.py b/dimos/navigation/replanning_a_star/controllers.py index 865aafb8be..07ba8c7119 100644 --- a/dimos/navigation/replanning_a_star/controllers.py +++ b/dimos/navigation/replanning_a_star/controllers.py @@ -19,8 +19,9 @@ from numpy.typing import NDArray from dimos.core.global_config import GlobalConfig -from dimos.msgs.geometry_msgs import Twist, Vector3 from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.utils.trigonometry import angle_diff diff --git a/dimos/navigation/replanning_a_star/global_planner.py b/dimos/navigation/replanning_a_star/global_planner.py index df2680a4a7..4c4e79cb7b 100644 --- a/dimos/navigation/replanning_a_star/global_planner.py +++ b/dimos/navigation/replanning_a_star/global_planner.py @@ -23,8 +23,8 @@ from dimos.core.global_config import GlobalConfig from dimos.core.resource import Resource from dimos.mapping.occupancy.path_resampling import smooth_resample_path -from dimos.msgs.geometry_msgs import Twist from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid from dimos.msgs.nav_msgs.Path import Path diff --git a/dimos/navigation/replanning_a_star/goal_validator.py b/dimos/navigation/replanning_a_star/goal_validator.py index 5cd093e955..b717c76295 100644 --- a/dimos/navigation/replanning_a_star/goal_validator.py +++ b/dimos/navigation/replanning_a_star/goal_validator.py @@ -16,8 +16,8 @@ import numpy as np -from dimos.msgs.geometry_msgs import Vector3, VectorLike -from dimos.msgs.nav_msgs import CostValues, OccupancyGrid +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike +from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid def find_safe_goal( diff --git a/dimos/navigation/replanning_a_star/local_planner.py b/dimos/navigation/replanning_a_star/local_planner.py index a5f8d9e457..d50d0def84 100644 --- a/dimos/navigation/replanning_a_star/local_planner.py +++ b/dimos/navigation/replanning_a_star/local_planner.py @@ -23,9 +23,10 @@ from dimos.core.global_config import GlobalConfig from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import Twist from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.navigation.base import NavigationState from dimos.navigation.replanning_a_star.controllers import Controller, PController from dimos.navigation.replanning_a_star.navigation_map import NavigationMap diff --git a/dimos/navigation/replanning_a_star/min_cost_astar.py b/dimos/navigation/replanning_a_star/min_cost_astar.py index c3430e64d9..55f502680c 100644 --- a/dimos/navigation/replanning_a_star/min_cost_astar.py +++ b/dimos/navigation/replanning_a_star/min_cost_astar.py @@ -14,8 +14,11 @@ import heapq -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, VectorLike -from dimos.msgs.nav_msgs import CostValues, OccupancyGrid, Path +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import VectorLike +from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.utils.logging_config import setup_logger # Try to import C++ extension for faster pathfinding diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index 28a22a2a86..796390f06c 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -21,8 +21,11 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PointStamped, PoseStamped, Twist -from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.navigation.base import NavigationInterface, NavigationState from dimos.navigation.replanning_a_star.global_planner import GlobalPlanner diff --git a/dimos/navigation/replanning_a_star/path_clearance.py b/dimos/navigation/replanning_a_star/path_clearance.py index e99fba26c3..7dc08d49e0 100644 --- a/dimos/navigation/replanning_a_star/path_clearance.py +++ b/dimos/navigation/replanning_a_star/path_clearance.py @@ -19,8 +19,8 @@ from dimos.core.global_config import GlobalConfig from dimos.mapping.occupancy.path_mask import make_path_mask -from dimos.msgs.nav_msgs import Path from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path class PathClearance: diff --git a/dimos/navigation/replanning_a_star/path_distancer.py b/dimos/navigation/replanning_a_star/path_distancer.py index 04d844267f..c50583ca33 100644 --- a/dimos/navigation/replanning_a_star/path_distancer.py +++ b/dimos/navigation/replanning_a_star/path_distancer.py @@ -17,7 +17,7 @@ import numpy as np from numpy.typing import NDArray -from dimos.msgs.nav_msgs import Path +from dimos.msgs.nav_msgs.Path import Path class PathDistancer: diff --git a/dimos/navigation/replanning_a_star/test_goal_validator.py b/dimos/navigation/replanning_a_star/test_goal_validator.py index 4cda9de863..69c7147696 100644 --- a/dimos/navigation/replanning_a_star/test_goal_validator.py +++ b/dimos/navigation/replanning_a_star/test_goal_validator.py @@ -15,7 +15,7 @@ import numpy as np import pytest -from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid from dimos.navigation.replanning_a_star.goal_validator import find_safe_goal from dimos.utils.data import get_data diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 230d94b50f..38c8e32847 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -27,26 +27,29 @@ from reactivex import operators as ops from reactivex.subject import Subject -from dimos import spec from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, ROSTransport -from dimos.msgs.geometry_msgs import ( - PoseStamped, - Quaternion, - Transform, - Twist, - TwistStamped, - Vector3, -) -from dimos.msgs.nav_msgs import Path -from dimos.msgs.sensor_msgs import Joy, PointCloud2 -from dimos.msgs.std_msgs import Bool, Int8 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Path import Path +from dimos.msgs.sensor_msgs.Joy import Joy +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.std_msgs.Bool import Bool +from dimos.msgs.std_msgs.Int8 import Int8 from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.navigation.base import NavigationInterface, NavigationState +from dimos.spec.control import LocalPlanner +from dimos.spec.mapping import GlobalPointcloud +from dimos.spec.nav import Nav +from dimos.spec.perception import Pointcloud from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion @@ -64,10 +67,10 @@ class Config(ModuleConfig): class ROSNav( Module[Config], NavigationInterface, - spec.Nav, - spec.GlobalPointcloud, - spec.Pointcloud, - spec.LocalPlanner, + Nav, + GlobalPointcloud, + Pointcloud, + LocalPlanner, ): default_config = Config diff --git a/dimos/navigation/visual/query.py b/dimos/navigation/visual/query.py index 0c84e8ac34..0693ca5dd1 100644 --- a/dimos/navigation/visual/query.py +++ b/dimos/navigation/visual/query.py @@ -16,7 +16,7 @@ from dimos.models.qwen.bbox import BBox from dimos.models.vl.base import VlModel -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.generic import extract_json_from_llm_response diff --git a/dimos/navigation/visual_servoing/detection_navigation.py b/dimos/navigation/visual_servoing/detection_navigation.py index 5f89bd1faa..351883e8ac 100644 --- a/dimos/navigation/visual_servoing/detection_navigation.py +++ b/dimos/navigation/visual_servoing/detection_navigation.py @@ -15,11 +15,15 @@ from dimos_lcm.sensor_msgs import CameraInfo as DimosLcmCameraInfo import numpy as np -from dimos.msgs.geometry_msgs import Transform, Twist, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox -from dimos.perception.detection.type.detection3d import Detection3DPC -from dimos.protocol.tf import LCMTF +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC +from dimos.protocol.tf.tf import LCMTF from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/navigation/visual_servoing/visual_servoing_2d.py b/dimos/navigation/visual_servoing/visual_servoing_2d.py index 032b5f3370..f424b21466 100644 --- a/dimos/navigation/visual_servoing/visual_servoing_2d.py +++ b/dimos/navigation/visual_servoing/visual_servoing_2d.py @@ -14,8 +14,9 @@ import numpy as np -from dimos.msgs.geometry_msgs import Twist, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo class VisualServoing2D: diff --git a/dimos/perception/__init__.py b/dimos/perception/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/perception/common/__init__.py b/dimos/perception/common/__init__.py deleted file mode 100644 index 5902f54bb8..0000000000 --- a/dimos/perception/common/__init__.py +++ /dev/null @@ -1,81 +0,0 @@ -from .utils import ( - BoundingBox2D, - CameraInfo, - Detection2D, - Detection3D, - Header, - Image, - ObjectData, - Pose, - Quaternion, - Union, - Vector, - Vector3, - bbox2d_to_corners, - colorize_depth, - combine_object_data, - cp, - cv2, - detection_results_to_object_data, - draw_bounding_box, - draw_object_detection_visualization, - draw_segmentation_mask, - extract_pose_from_detection3d, - find_clicked_detection, - load_camera_info, - load_camera_info_opencv, - logger, - np, - point_in_bbox, - project_2d_points_to_3d, - project_2d_points_to_3d_cpu, - project_2d_points_to_3d_cuda, - project_3d_points_to_2d, - project_3d_points_to_2d_cpu, - project_3d_points_to_2d_cuda, - rectify_image, - setup_logger, - torch, - yaml, -) - -__all__ = [ - "BoundingBox2D", - "CameraInfo", - "Detection2D", - "Detection3D", - "Header", - "Image", - "ObjectData", - "Pose", - "Quaternion", - "Union", - "Vector", - "Vector3", - "bbox2d_to_corners", - "colorize_depth", - "combine_object_data", - "cp", - "cv2", - "detection_results_to_object_data", - "draw_bounding_box", - "draw_object_detection_visualization", - "draw_segmentation_mask", - "extract_pose_from_detection3d", - "find_clicked_detection", - "load_camera_info", - "load_camera_info_opencv", - "logger", - "np", - "point_in_bbox", - "project_2d_points_to_3d", - "project_2d_points_to_3d_cpu", - "project_2d_points_to_3d_cuda", - "project_3d_points_to_2d", - "project_3d_points_to_2d_cpu", - "project_3d_points_to_2d_cuda", - "rectify_image", - "setup_logger", - "torch", - "yaml", -] diff --git a/dimos/perception/common/utils.py b/dimos/perception/common/utils.py index c5f550ade3..1670d31998 100644 --- a/dimos/perception/common/utils.py +++ b/dimos/perception/common/utils.py @@ -25,9 +25,11 @@ import torch import yaml # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.std_msgs import Header +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.std_msgs.Header import Header from dimos.types.manipulation import ObjectData from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger diff --git a/dimos/perception/demo_object_scene_registration.py b/dimos/perception/demo_object_scene_registration.py index ad98d0474a..cdb09d359e 100644 --- a/dimos/perception/demo_object_scene_registration.py +++ b/dimos/perception/demo_object_scene_registration.py @@ -15,8 +15,8 @@ from dimos.agents.agent import agent from dimos.core.blueprints import autoconnect -from dimos.hardware.sensors.camera.realsense import realsense_camera -from dimos.hardware.sensors.camera.zed import zed_camera +from dimos.hardware.sensors.camera.realsense.camera import realsense_camera +from dimos.hardware.sensors.camera.zed.compat import zed_camera from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import object_scene_registration_module from dimos.robot.foxglove_bridge import foxglove_bridge diff --git a/dimos/perception/detection/__init__.py b/dimos/perception/detection/__init__.py deleted file mode 100644 index ae9f8cb14d..0000000000 --- a/dimos/perception/detection/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "detectors": ["Detector", "Yolo2DDetector"], - "module2D": ["Detection2DModule"], - "module3D": ["Detection3DModule"], - }, -) diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 8c1a65eb8b..5f8f1bc4b9 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -23,23 +23,23 @@ import pytest from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import Transform -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.module2D import Detection2DModule from dimos.perception.detection.module3D import Detection3DModule from dimos.perception.detection.moduleDB import ObjectDBModule -from dimos.perception.detection.type import ( - Detection2D, - Detection3DPC, - ImageDetections2D, - ImageDetections3DPC, -) -from dimos.protocol.tf import TF +from dimos.perception.detection.type.detection2d.base import Detection2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC +from dimos.protocol.tf.tf import TF from dimos.robot.unitree.go2 import connection from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay class Moment(TypedDict, total=False): @@ -203,7 +203,7 @@ def detection3dpc(detections3dpc) -> Detection3DPC: @pytest.fixture(scope="session") def get_moment_2d(get_moment) -> Generator[Callable[[], Moment2D], None, None]: - from dimos.perception.detection.detectors import Yolo2DDetector + from dimos.perception.detection.detectors.yolo import Yolo2DDetector c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) module = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) @@ -262,7 +262,7 @@ def moment_provider(**kwargs) -> Moment3D: @pytest.fixture(scope="session") def object_db_module(get_moment): """Create and populate an ObjectDBModule with detections from multiple frames.""" - from dimos.perception.detection.detectors import Yolo2DDetector + from dimos.perception.detection.detectors.yolo import Yolo2DDetector c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) diff --git a/dimos/perception/detection/detectors/__init__.py b/dimos/perception/detection/detectors/__init__.py deleted file mode 100644 index 2f151fe3ef..0000000000 --- a/dimos/perception/detection/detectors/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# from dimos.perception.detection.detectors.detic import Detic2DDetector -from dimos.perception.detection.detectors.types import Detector -from dimos.perception.detection.detectors.yolo import Yolo2DDetector - -__all__ = [ - "Detector", - "Yolo2DDetector", -] diff --git a/dimos/perception/detection/detectors/types.py b/dimos/perception/detection/detectors/base.py similarity index 84% rename from dimos/perception/detection/detectors/types.py rename to dimos/perception/detection/detectors/base.py index e85c5ae18e..40aa82e5bd 100644 --- a/dimos/perception/detection/detectors/types.py +++ b/dimos/perception/detection/detectors/base.py @@ -14,8 +14,8 @@ from abc import ABC, abstractmethod -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D class Detector(ABC): diff --git a/dimos/perception/detection/detectors/conftest.py b/dimos/perception/detection/detectors/conftest.py index 6a2c041a8b..bb9a47e0eb 100644 --- a/dimos/perception/detection/detectors/conftest.py +++ b/dimos/perception/detection/detectors/conftest.py @@ -14,7 +14,7 @@ import pytest -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector from dimos.perception.detection.detectors.yolo import Yolo2DDetector from dimos.perception.detection.detectors.yoloe import Yoloe2DDetector, YoloePromptMode diff --git a/dimos/perception/detection/detectors/person/test_person_detectors.py b/dimos/perception/detection/detectors/person/test_person_detectors.py index 2ed7cdc7dc..6130e5888a 100644 --- a/dimos/perception/detection/detectors/person/test_person_detectors.py +++ b/dimos/perception/detection/detectors/person/test_person_detectors.py @@ -14,7 +14,8 @@ import pytest -from dimos.perception.detection.type import Detection2DPerson, ImageDetections2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection2d.person import Detection2DPerson @pytest.fixture(scope="session") diff --git a/dimos/perception/detection/detectors/person/yolo.py b/dimos/perception/detection/detectors/person/yolo.py index 519f45f2f6..26d68a4510 100644 --- a/dimos/perception/detection/detectors/person/yolo.py +++ b/dimos/perception/detection/detectors/person/yolo.py @@ -14,9 +14,9 @@ from ultralytics import YOLO # type: ignore[attr-defined, import-not-found] -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.detectors.types import Detector -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.detectors.base import Detector +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.utils.data import get_data from dimos.utils.gpu_utils import is_cuda_available from dimos.utils.logging_config import setup_logger diff --git a/dimos/perception/detection/detectors/test_bbox_detectors.py b/dimos/perception/detection/detectors/test_bbox_detectors.py index 2e69016eb5..c8112e9aab 100644 --- a/dimos/perception/detection/detectors/test_bbox_detectors.py +++ b/dimos/perception/detection/detectors/test_bbox_detectors.py @@ -17,8 +17,9 @@ from reactivex.disposable import CompositeDisposable from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import Detection2D, ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.base import Detection2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D @pytest.fixture(params=["bbox_detector", "person_detector", "yoloe_detector"], scope="session") diff --git a/dimos/perception/detection/detectors/yolo.py b/dimos/perception/detection/detectors/yolo.py index c9a65a120e..64565cce7a 100644 --- a/dimos/perception/detection/detectors/yolo.py +++ b/dimos/perception/detection/detectors/yolo.py @@ -14,9 +14,9 @@ from ultralytics import YOLO # type: ignore[attr-defined, import-not-found] -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.detectors.types import Detector -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.detectors.base import Detector +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.utils.data import get_data from dimos.utils.gpu_utils import is_cuda_available from dimos.utils.logging_config import setup_logger diff --git a/dimos/perception/detection/detectors/yoloe.py b/dimos/perception/detection/detectors/yoloe.py index 9c9881209c..536dd9f497 100644 --- a/dimos/perception/detection/detectors/yoloe.py +++ b/dimos/perception/detection/detectors/yoloe.py @@ -20,9 +20,9 @@ from numpy.typing import NDArray from ultralytics import YOLOE # type: ignore[attr-defined, import-not-found] -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.detectors.types import Detector -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.detectors.base import Detector +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.utils.data import get_data from dimos.utils.gpu_utils import is_cuda_available diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index 0a07b1238d..b6d0c9358c 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -22,18 +22,20 @@ from reactivex.observable import Observable from reactivex.subject import Subject -from dimos import spec from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Transform, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo, Image -from dimos.msgs.sensor_msgs.Image import sharpness_barrier -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection.detectors import Detector # type: ignore[attr-defined] +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray +from dimos.perception.detection.detectors.base import Detector from dimos.perception.detection.detectors.yolo import Yolo2DDetector -from dimos.perception.detection.type import Filter2D, ImageDetections2D +from dimos.perception.detection.type.detection2d.base import Filter2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.spec.perception import Camera from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.reactive import backpressure @@ -158,7 +160,7 @@ def stop(self) -> None: def deploy( # type: ignore[no-untyped-def] dimos: ModuleCoordinator, - camera: spec.Camera, + camera: Camera, prefix: str = "/detector2d", **kwargs, ) -> Detection2DModule: diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index 96ae4e8297..fa392dc799 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -22,19 +22,23 @@ from reactivex import operators as ops from reactivex.observable import Observable -from dimos import spec from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.module2D import Detection2DModule from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D -from dimos.perception.detection.type.detection3d import Detection3DPC from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC +from dimos.spec.perception import Camera, Pointcloud from dimos.types.timestamped import align_timestamped from dimos.utils.reactive import backpressure @@ -177,7 +181,7 @@ def detection2d_to_3d(args): # type: ignore[no-untyped-def] transform = self.tf.get("camera_optical", pc.frame_id, detections.image.ts, 5.0) return self.process_frame(detections, pc, transform) - self.detection_stream_3d = align_timestamped( + self.detection_stream_3d = align_timestamped( # type: ignore[type-var] backpressure(self.detection_stream_2d()), self.pointcloud.observable(), # type: ignore[no-untyped-call] match_tolerance=0.25, @@ -203,8 +207,8 @@ def _publish_detections(self, detections: ImageDetections3DPC) -> None: def deploy( # type: ignore[no-untyped-def] dimos: ModuleCoordinator, - lidar: spec.Pointcloud, - camera: spec.Camera, + lidar: Pointcloud, + camera: Camera, prefix: str = "/detector3d", **kwargs, ) -> "ModuleProxy": diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index bc0a346a59..5672786b94 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -25,12 +25,16 @@ from dimos.core.core import rpc from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.module3D import Detection3DModule -from dimos.perception.detection.type.detection3d import Detection3DPC from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC from dimos.perception.detection.type.utils import TableStr diff --git a/dimos/perception/detection/objectDB.py b/dimos/perception/detection/objectDB.py index 9af8058c55..5b73e97742 100644 --- a/dimos/perception/detection/objectDB.py +++ b/dimos/perception/detection/objectDB.py @@ -20,11 +20,11 @@ import open3d as o3d # type: ignore[import-untyped] -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import Vector3 + from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.perception.detection.type.detection3d.object import Object logger = setup_logger() diff --git a/dimos/perception/detection/person_tracker.py b/dimos/perception/detection/person_tracker.py index 913043f312..9dbba210a2 100644 --- a/dimos/perception/detection/person_tracker.py +++ b/dimos/perception/detection/person_tracker.py @@ -21,10 +21,13 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo, Image -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.types.timestamped import align_timestamped from dimos.utils.reactive import backpressure diff --git a/dimos/perception/detection/reid/__init__.py b/dimos/perception/detection/reid/__init__.py deleted file mode 100644 index 31d50a894b..0000000000 --- a/dimos/perception/detection/reid/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem -from dimos.perception.detection.reid.module import Config, ReidModule -from dimos.perception.detection.reid.type import IDSystem, PassthroughIDSystem - -__all__ = [ - "Config", - "EmbeddingIDSystem", - # ID Systems - "IDSystem", - "PassthroughIDSystem", - # Module - "ReidModule", -] diff --git a/dimos/perception/detection/reid/embedding_id_system.py b/dimos/perception/detection/reid/embedding_id_system.py index 15bb491f5c..faf322de07 100644 --- a/dimos/perception/detection/reid/embedding_id_system.py +++ b/dimos/perception/detection/reid/embedding_id_system.py @@ -19,7 +19,7 @@ from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.perception.detection.reid.type import IDSystem -from dimos.perception.detection.type import Detection2DBBox +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox class EmbeddingIDSystem(IDSystem): diff --git a/dimos/perception/detection/reid/module.py b/dimos/perception/detection/reid/module.py index 0a359746d3..2bb0ecfbb2 100644 --- a/dimos/perception/detection/reid/module.py +++ b/dimos/perception/detection/reid/module.py @@ -24,8 +24,8 @@ from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem from dimos.perception.detection.reid.type import IDSystem from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D @@ -48,7 +48,7 @@ def __init__(self, idsystem: IDSystem | None = None, **kwargs) -> None: # type: super().__init__(**kwargs) if idsystem is None: try: - from dimos.models.embedding import TorchReIDModel + from dimos.models.embedding.treid import TorchReIDModel idsystem = EmbeddingIDSystem(model=TorchReIDModel, padding=0) # type: ignore[arg-type] except Exception as e: diff --git a/dimos/perception/detection/reid/test_embedding_id_system.py b/dimos/perception/detection/reid/test_embedding_id_system.py index cc8632627f..2916c9040d 100644 --- a/dimos/perception/detection/reid/test_embedding_id_system.py +++ b/dimos/perception/detection/reid/test_embedding_id_system.py @@ -15,7 +15,7 @@ import numpy as np import pytest -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem from dimos.utils.data import get_data diff --git a/dimos/perception/detection/reid/test_module.py b/dimos/perception/detection/reid/test_module.py index f5672c1f67..aac6ba11d1 100644 --- a/dimos/perception/detection/reid/test_module.py +++ b/dimos/perception/detection/reid/test_module.py @@ -15,7 +15,7 @@ import pytest from dimos.core.transport import LCMTransport -from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem from dimos.perception.detection.reid.module import ReidModule @@ -23,7 +23,7 @@ @pytest.mark.tool def test_reid_ingress(imageDetections2d) -> None: try: - from dimos.models.embedding import TorchReIDModel + from dimos.models.embedding.treid import TorchReIDModel except Exception: pytest.skip("TorchReIDModel not available") diff --git a/dimos/perception/detection/type/__init__.py b/dimos/perception/detection/type/__init__.py deleted file mode 100644 index b14464d4fa..0000000000 --- a/dimos/perception/detection/type/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "detection2d.base": [ - "Detection2D", - "Filter2D", - ], - "detection2d.bbox": [ - "Detection2DBBox", - ], - "detection2d.person": [ - "Detection2DPerson", - ], - "detection2d.point": [ - "Detection2DPoint", - ], - "detection2d.imageDetections2D": [ - "ImageDetections2D", - ], - "detection3d": [ - "Detection3D", - "Detection3DBBox", - "Detection3DPC", - "ImageDetections3DPC", - "PointCloudFilter", - "height_filter", - "radius_outlier", - "raycast", - "statistical", - ], - "imageDetections": ["ImageDetections"], - "utils": ["TableStr"], - }, -) diff --git a/dimos/perception/detection/type/detection2d/__init__.py b/dimos/perception/detection/type/detection2d/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/perception/detection/type/detection2d/base.py b/dimos/perception/detection/type/detection2d/base.py index ee9374af8c..ef05813118 100644 --- a/dimos/perception/detection/type/detection2d/base.py +++ b/dimos/perception/detection/type/detection2d/base.py @@ -17,8 +17,8 @@ from dimos_lcm.vision_msgs import Detection2D as ROSDetection2D -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos.msgs.sensor_msgs.Image import Image from dimos.types.timestamped import Timestamped diff --git a/dimos/perception/detection/type/detection2d/bbox.py b/dimos/perception/detection/type/detection2d/bbox.py index 45dc848e9d..9ce3f11b96 100644 --- a/dimos/perception/detection/type/detection2d/bbox.py +++ b/dimos/perception/detection/type/detection2d/bbox.py @@ -22,7 +22,7 @@ from typing_extensions import Self from ultralytics.engine.results import Results # type: ignore[import-not-found] - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image from dimos_lcm.foxglove_msgs.ImageAnnotations import ( PointsAnnotation, @@ -40,9 +40,9 @@ from rich.console import Console from rich.text import Text -from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.std_msgs import Header +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos.msgs.std_msgs.Header import Header from dimos.perception.detection.type.detection2d.base import Detection2D from dimos.types.timestamped import to_ros_stamp, to_timestamp from dimos.utils.decorators.decorators import simple_mcache diff --git a/dimos/perception/detection/type/detection2d/imageDetections2D.py b/dimos/perception/detection/type/detection2d/imageDetections2D.py index 34033a9c50..507125c333 100644 --- a/dimos/perception/detection/type/detection2d/imageDetections2D.py +++ b/dimos/perception/detection/type/detection2d/imageDetections2D.py @@ -27,8 +27,8 @@ if TYPE_CHECKING: from ultralytics.engine.results import Results - from dimos.msgs.sensor_msgs import Image - from dimos.msgs.vision_msgs import Detection2DArray + from dimos.msgs.sensor_msgs.Image import Image + from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray T2D = TypeVar("T2D", bound=Detection2D, default=Detection2DBBox) diff --git a/dimos/perception/detection/type/detection2d/person.py b/dimos/perception/detection/type/detection2d/person.py index efb12ebdbc..e85229719a 100644 --- a/dimos/perception/detection/type/detection2d/person.py +++ b/dimos/perception/detection/type/detection2d/person.py @@ -25,7 +25,7 @@ import numpy as np from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.type.detection2d.bbox import Bbox, Detection2DBBox from dimos.types.timestamped import to_ros_stamp from dimos.utils.decorators.decorators import simple_mcache diff --git a/dimos/perception/detection/type/detection2d/point.py b/dimos/perception/detection/type/detection2d/point.py index 216ec57b82..0155bcb9cd 100644 --- a/dimos/perception/detection/type/detection2d/point.py +++ b/dimos/perception/detection/type/detection2d/point.py @@ -31,14 +31,14 @@ Pose2D, ) -from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.std_msgs import Header +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos.msgs.std_msgs.Header import Header from dimos.perception.detection.type.detection2d.base import Detection2D from dimos.types.timestamped import to_ros_stamp if TYPE_CHECKING: - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image @dataclass diff --git a/dimos/perception/detection/type/detection2d/seg.py b/dimos/perception/detection/type/detection2d/seg.py index 5d4d55d0c3..aca1e34b7e 100644 --- a/dimos/perception/detection/type/detection2d/seg.py +++ b/dimos/perception/detection/type/detection2d/seg.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from ultralytics.engine.results import Results - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image @dataclass diff --git a/dimos/perception/detection/type/detection2d/test_imageDetections2D.py b/dimos/perception/detection/type/detection2d/test_imageDetections2D.py index 83487d2c25..4897d8d034 100644 --- a/dimos/perception/detection/type/detection2d/test_imageDetections2D.py +++ b/dimos/perception/detection/type/detection2d/test_imageDetections2D.py @@ -13,7 +13,7 @@ # limitations under the License. import pytest -from dimos.perception.detection.type import ImageDetections2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D def test_from_ros_detection2d_array(get_moment_2d) -> None: diff --git a/dimos/perception/detection/type/detection2d/test_person.py b/dimos/perception/detection/type/detection2d/test_person.py index 06c5883ae2..988222e120 100644 --- a/dimos/perception/detection/type/detection2d/test_person.py +++ b/dimos/perception/detection/type/detection2d/test_person.py @@ -17,7 +17,7 @@ def test_person_ros_confidence() -> None: """Test that Detection2DPerson preserves confidence when converting to ROS format.""" - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector from dimos.perception.detection.type.detection2d.person import Detection2DPerson from dimos.utils.data import get_data diff --git a/dimos/perception/detection/type/detection3d/__init__.py b/dimos/perception/detection/type/detection3d/__init__.py deleted file mode 100644 index 53ab73259e..0000000000 --- a/dimos/perception/detection/type/detection3d/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.perception.detection.type.detection3d.base import Detection3D -from dimos.perception.detection.type.detection3d.bbox import Detection3DBBox -from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC -from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC -from dimos.perception.detection.type.detection3d.pointcloud_filters import ( - PointCloudFilter, - height_filter, - radius_outlier, - raycast, - statistical, -) - -__all__ = [ - "Detection3D", - "Detection3DBBox", - "Detection3DPC", - "ImageDetections3DPC", - "PointCloudFilter", - "height_filter", - "radius_outlier", - "raycast", - "statistical", -] diff --git a/dimos/perception/detection/type/detection3d/base.py b/dimos/perception/detection/type/detection3d/base.py index a5dbb742b8..afe37aac6e 100644 --- a/dimos/perception/detection/type/detection3d/base.py +++ b/dimos/perception/detection/type/detection3d/base.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING -from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.geometry_msgs.Transform import Transform from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox if TYPE_CHECKING: diff --git a/dimos/perception/detection/type/detection3d/bbox.py b/dimos/perception/detection/type/detection3d/bbox.py index bdf2d27a7c..a3ae68a766 100644 --- a/dimos/perception/detection/type/detection3d/bbox.py +++ b/dimos/perception/detection/type/detection3d/bbox.py @@ -20,9 +20,13 @@ from dimos_lcm.vision_msgs import ObjectHypothesis, ObjectHypothesisWithPose -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection3D +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection3D import Detection3D from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox diff --git a/dimos/perception/detection/type/detection3d/object.py b/dimos/perception/detection/type/detection3d/object.py index ec160c4a68..639ea73ae5 100644 --- a/dimos/perception/detection/type/detection3d/object.py +++ b/dimos/perception/detection/type/detection3d/object.py @@ -24,10 +24,15 @@ import numpy as np import open3d as o3d # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection3D as ROSDetection3D, Detection3DArray +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection3D import Detection3D as ROSDetection3D +from dimos.msgs.vision_msgs.Detection3DArray import Detection3DArray from dimos.perception.detection.type.detection2d.seg import Detection2DSeg from dimos.perception.detection.type.detection3d.base import Detection3D diff --git a/dimos/perception/detection/type/detection3d/pointcloud.py b/dimos/perception/detection/type/detection3d/pointcloud.py index 741b9c7498..5ddec06fd5 100644 --- a/dimos/perception/detection/type/detection3d/pointcloud.py +++ b/dimos/perception/detection/type/detection3d/pointcloud.py @@ -33,8 +33,10 @@ import numpy as np from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.perception.detection.type.detection3d.base import Detection3D from dimos.perception.detection.type.detection3d.pointcloud_filters import ( PointCloudFilter, diff --git a/dimos/perception/detection/type/detection3d/pointcloud_filters.py b/dimos/perception/detection/type/detection3d/pointcloud_filters.py index 59ad6200d9..fdb2afeebb 100644 --- a/dimos/perception/detection/type/detection3d/pointcloud_filters.py +++ b/dimos/perception/detection/type/detection3d/pointcloud_filters.py @@ -18,8 +18,8 @@ from dimos_lcm.sensor_msgs import CameraInfo -from dimos.msgs.geometry_msgs import Transform -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox # Filters take Detection2DBBox, PointCloud2, CameraInfo, Transform and return filtered PointCloud2 or None diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index 12a1f4efb9..25cd45545a 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -20,14 +20,14 @@ from dimos_lcm.vision_msgs import Detection2DArray -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.std_msgs import Header +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos.msgs.std_msgs.Header import Header from dimos.perception.detection.type.utils import TableStr if TYPE_CHECKING: from collections.abc import Callable, Iterator - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.type.detection2d.base import Detection2D T = TypeVar("T", bound=Detection2D) diff --git a/dimos/perception/detection/type/test_object3d.py b/dimos/perception/detection/type/test_object3d.py index 7057fbb9cb..ff8931e353 100644 --- a/dimos/perception/detection/type/test_object3d.py +++ b/dimos/perception/detection/type/test_object3d.py @@ -15,7 +15,7 @@ import pytest from dimos.perception.detection.moduleDB import Object3D -from dimos.perception.detection.type.detection3d import ImageDetections3DPC +from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC def test_first_object(first_object) -> None: diff --git a/dimos/perception/experimental/__init__.py b/dimos/perception/experimental/__init__.py deleted file mode 100644 index 39ef33521d..0000000000 --- a/dimos/perception/experimental/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Experimental perception modules.""" diff --git a/dimos/perception/experimental/temporal_memory/clip_filter.py b/dimos/perception/experimental/temporal_memory/clip_filter.py index d747899452..9bea000712 100644 --- a/dimos/perception/experimental/temporal_memory/clip_filter.py +++ b/dimos/perception/experimental/temporal_memory/clip_filter.py @@ -18,7 +18,7 @@ import numpy as np -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/perception/experimental/temporal_memory/entity_graph_db.py b/dimos/perception/experimental/temporal_memory/entity_graph_db.py index 11c90cda87..bdc7137ce7 100644 --- a/dimos/perception/experimental/temporal_memory/entity_graph_db.py +++ b/dimos/perception/experimental/temporal_memory/entity_graph_db.py @@ -30,9 +30,12 @@ from dimos.utils.logging_config import setup_logger +from .temporal_utils.parsers import parse_batch_distance_response +from .temporal_utils.prompts import build_batch_distance_estimation_prompt + if TYPE_CHECKING: from dimos.models.vl.base import VlModel - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image logger = setup_logger() @@ -564,7 +567,6 @@ def estimate_and_save_distances( """Estimate distances between entities using VLM and save to database.""" if not frame_image: return - from . import temporal_utils as tu enriched_entities: list[dict[str, Any]] = [] for entity in parsed.get("new_entities", []): @@ -593,8 +595,8 @@ def estimate_and_save_distances( if not pairs: return try: - response = vlm.query(frame_image, tu.build_batch_distance_estimation_prompt(pairs)) - for r in tu.parse_batch_distance_response(response, pairs): + response = vlm.query(frame_image, build_batch_distance_estimation_prompt(pairs)) + for r in parse_batch_distance_response(response, pairs): if r["category"] in ("near", "medium", "far"): self.add_distance( entity_a_id=r["entity_a_id"], diff --git a/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py b/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py index fc2c9c8a79..4c910a1b88 100644 --- a/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py +++ b/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py @@ -25,7 +25,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image @dataclass diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py index 8841d3a6b0..d4e343872b 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -38,16 +38,16 @@ from dimos.core.stream import In, Out from dimos.models.vl.base import VlModel from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.sensor_msgs.Image import sharpness_barrier +from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.msgs.visualization_msgs.EntityMarkers import EntityMarkers, Marker from dimos.utils.logging_config import get_run_log_dir, setup_logger -from . import temporal_utils as tu from .clip_filter import CLIP_AVAILABLE, adaptive_keyframes from .entity_graph_db import EntityGraphDB from .frame_window_accumulator import Frame, FrameWindowAccumulator from .temporal_state import TemporalState +from .temporal_utils.graph_utils import build_graph_context, extract_time_window +from .temporal_utils.helpers import is_scene_stale from .window_analyzer import WindowAnalyzer try: @@ -376,7 +376,7 @@ def _analyze_window(self) -> None: w_start, w_end = window_frames[0].timestamp_s, window_frames[-1].timestamp_s # Skip stale scenes (frames too close together / camera not moving) - if tu.is_scene_stale(window_frames, self.config.stale_scene_threshold): + if is_scene_stale(window_frames, self.config.stale_scene_threshold): logger.info(f"[temporal-memory] skipping stale window [{w_start:.1f}-{w_end:.1f}s]") return @@ -553,13 +553,13 @@ def query(self, question: str) -> str: # Graph context if self._graph_db: - time_window_s = tu.extract_time_window(question) + time_window_s = extract_time_window(question) all_entity_ids = [ e["id"] for e in snap.entity_roster if isinstance(e, dict) and "id" in e ] if all_entity_ids: logger.info(f"query: building graph context for {len(all_entity_ids)} entities") - graph_context = tu.build_graph_context( + graph_context = build_graph_context( graph_db=self._graph_db, entity_ids=all_entity_ids, time_window_s=time_window_s, diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py b/dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py deleted file mode 100644 index d8119a5159..0000000000 --- a/dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Temporal memory utilities.""" - -from .graph_utils import build_graph_context, extract_time_window -from .helpers import clamp_text, format_timestamp, is_scene_stale, next_entity_id_hint -from .parsers import parse_batch_distance_response, parse_window_response -from .prompts import ( - WINDOW_RESPONSE_SCHEMA, - build_batch_distance_estimation_prompt, - build_distance_estimation_prompt, - build_query_prompt, - build_summary_prompt, - build_window_prompt, - get_structured_output_format, -) - -__all__ = [ - "WINDOW_RESPONSE_SCHEMA", - "build_batch_distance_estimation_prompt", - "build_distance_estimation_prompt", - "build_graph_context", - "build_query_prompt", - "build_summary_prompt", - "build_window_prompt", - "clamp_text", - "extract_time_window", - "format_timestamp", - "get_structured_output_format", - "is_scene_stale", - "next_entity_id_hint", - "parse_batch_distance_response", - "parse_window_response", -] diff --git a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py index 5b37b66770..81df107ecf 100644 --- a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py +++ b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py @@ -34,15 +34,17 @@ from dimos.core.stream import Out from dimos.core.transport import LCMTransport from dimos.models.vl.base import VlModel -from dimos.msgs.sensor_msgs import Image -from dimos.perception.experimental.temporal_memory import ( +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.experimental.temporal_memory.entity_graph_db import EntityGraphDB +from dimos.perception.experimental.temporal_memory.frame_window_accumulator import ( Frame, FrameWindowAccumulator, +) +from dimos.perception.experimental.temporal_memory.temporal_memory import ( TemporalMemory, TemporalMemoryConfig, - TemporalState, ) -from dimos.perception.experimental.temporal_memory.entity_graph_db import EntityGraphDB +from dimos.perception.experimental.temporal_memory.temporal_state import TemporalState from dimos.perception.experimental.temporal_memory.temporal_utils.graph_utils import ( extract_time_window, ) diff --git a/dimos/perception/experimental/temporal_memory/window_analyzer.py b/dimos/perception/experimental/temporal_memory/window_analyzer.py index cd01a3056d..3c233f8e5b 100644 --- a/dimos/perception/experimental/temporal_memory/window_analyzer.py +++ b/dimos/perception/experimental/temporal_memory/window_analyzer.py @@ -25,11 +25,17 @@ from dimos.utils.logging_config import setup_logger -from . import temporal_utils as tu +from .temporal_utils.parsers import parse_window_response +from .temporal_utils.prompts import ( + build_query_prompt, + build_summary_prompt, + build_window_prompt, + get_structured_output_format, +) if TYPE_CHECKING: from dimos.models.vl.base import VlModel - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image from .frame_window_accumulator import Frame @@ -87,14 +93,14 @@ def analyze_window( w_end: float, ) -> AnalysisResult | None: """Run VLM window analysis. Returns None on failure.""" - query = tu.build_window_prompt( + query = build_window_prompt( w_start=w_start, w_end=w_end, frame_count=len(frames), state=state_dict, ) try: - fmt = tu.get_structured_output_format() + fmt = get_structured_output_format() if len(frames) > 1: responses = self._vlm.query_batch( [f.image for f in frames], query, response_format=fmt @@ -109,7 +115,7 @@ def analyze_window( if raw is None: return None - parsed = tu.parse_window_response(raw, w_start, w_end, len(frames)) + parsed = parse_window_response(raw, w_start, w_end, len(frames)) return AnalysisResult(parsed=parsed, raw_vlm_response=raw, w_start=w_start, w_end=w_end) # It's called from the orchestrator, not here. @@ -124,7 +130,7 @@ def update_summary( if not chunk_buffer or not latest_frame: return None - prompt = tu.build_summary_prompt( + prompt = build_summary_prompt( rolling_summary=rolling_summary, chunk_windows=chunk_buffer, ) @@ -143,7 +149,7 @@ def answer_query( latest_frame: Image, ) -> QueryResult | None: """Answer a user query. Returns None on failure.""" - prompt = tu.build_query_prompt(question=question, context=context) + prompt = build_query_prompt(question=question, context=context) try: raw = self._vlm.query(latest_frame, prompt) return QueryResult(answer=raw.strip(), raw_vlm_response=raw) diff --git a/dimos/perception/object_scene_registration.py b/dimos/perception/object_scene_registration.py index ee7b87b534..5fb1748032 100644 --- a/dimos/perception/object_scene_registration.py +++ b/dimos/perception/object_scene_registration.py @@ -24,14 +24,16 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 -from dimos.msgs.sensor_msgs.Image import ImageFormat -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray +from dimos.msgs.vision_msgs.Detection3DArray import Detection3DArray from dimos.perception.detection.detectors.yoloe import Yoloe2DDetector, YoloePromptMode from dimos.perception.detection.objectDB import ObjectDB -from dimos.perception.detection.type import ImageDetections2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.perception.detection.type.detection3d.object import ( Object, Object as DetObject, diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index 29a9ecc034..6afc5e0814 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -31,15 +31,16 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import ( - CameraInfo, - Image, - ImageFormat, -) -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray -from dimos.protocol.tf import TF +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray +from dimos.msgs.vision_msgs.Detection3DArray import Detection3DArray +from dimos.protocol.tf.tf import TF from dimos.types.timestamped import align_timestamped from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import ( diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index 03f3991081..a53d331aef 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -35,9 +35,9 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.utils.logging_config import setup_logger logger = setup_logger(level=logging.INFO) diff --git a/dimos/perception/object_tracker_3d.py b/dimos/perception/object_tracker_3d.py index da35577d0d..317a58dba0 100644 --- a/dimos/perception/object_tracker_3d.py +++ b/dimos/perception/object_tracker_3d.py @@ -24,12 +24,16 @@ from dimos.core.core import rpc from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray +from dimos.msgs.vision_msgs.Detection3DArray import Detection3DArray from dimos.perception.object_tracker_2d import ObjectTracker2D -from dimos.protocol.tf import TF +from dimos.protocol.tf.tf import TF from dimos.types.timestamped import align_timestamped from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import ( diff --git a/dimos/perception/perceive_loop_skill.py b/dimos/perception/perceive_loop_skill.py index 0d84e40897..4532e61c2e 100644 --- a/dimos/perception/perceive_loop_skill.py +++ b/dimos/perception/perceive_loop_skill.py @@ -26,8 +26,7 @@ from dimos.core.module import Module from dimos.core.stream import In from dimos.models.vl.create import create -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.sensor_msgs.Image import sharpness_window +from dimos.msgs.sensor_msgs.Image import Image, sharpness_window from dimos.utils.logging_config import setup_logger from dimos.utils.reactive import backpressure diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 0cb4ab74c1..fe6d7d50e0 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -27,7 +27,6 @@ from reactivex import Observable, interval, operators as ops from reactivex.disposable import Disposable -from dimos import spec from dimos.agents_deprecated.memory.image_embedding import ImageEmbeddingProvider from dimos.agents_deprecated.memory.spatial_vector_db import SpatialVectorDB from dimos.agents_deprecated.memory.visual_memory import VisualMemory @@ -36,12 +35,13 @@ from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image +from dimos.spec.perception import Camera from dimos.types.robot_location import RobotLocation from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import Vector3 + from dimos.msgs.geometry_msgs.Vector3 import Vector3 _OUTPUT_DIR = DIMOS_PROJECT_ROOT / "assets" / "output" _MEMORY_DIR = _OUTPUT_DIR / "memory" @@ -577,7 +577,7 @@ def query_tagged_location(self, query: str) -> RobotLocation | None: def deploy( # type: ignore[no-untyped-def] dimos: ModuleCoordinator, - camera: spec.Camera, + camera: Camera, ): spatial_memory = dimos.deploy(SpatialMemory, db_path="/tmp/spatial_memory_db") # type: ignore[attr-defined] spatial_memory.color_image.connect(camera.color_image) diff --git a/dimos/perception/test_spatial_memory.py b/dimos/perception/test_spatial_memory.py index 433896aefe..322513d459 100644 --- a/dimos/perception/test_spatial_memory.py +++ b/dimos/perception/test_spatial_memory.py @@ -22,7 +22,7 @@ from reactivex import operators as ops from reactivex.scheduler import ThreadPoolScheduler -from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.geometry_msgs.Pose import Pose from dimos.perception.spatial_perception import SpatialMemory from dimos.stream.video_provider import VideoProvider diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py index 22aa4d4ce8..d8567036bf 100644 --- a/dimos/perception/test_spatial_memory_module.py +++ b/dimos/perception/test_spatial_memory_module.py @@ -24,13 +24,13 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import Out from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import Transform -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.spatial_perception import SpatialMemory from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay logger = setup_logger() diff --git a/dimos/protocol/__init__.py b/dimos/protocol/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/protocol/encode/__init__.py b/dimos/protocol/encode/encoder.py similarity index 82% rename from dimos/protocol/encode/__init__.py rename to dimos/protocol/encode/encoder.py index 87386a09e5..b6e00e4b1c 100644 --- a/dimos/protocol/encode/__init__.py +++ b/dimos/protocol/encode/encoder.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC, abstractmethod import json from typing import Generic, Protocol, TypeVar diff --git a/dimos/protocol/pubsub/__init__.py b/dimos/protocol/pubsub/__init__.py deleted file mode 100644 index 94a58b60de..0000000000 --- a/dimos/protocol/pubsub/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -import dimos.protocol.pubsub.impl.lcmpubsub as lcm -from dimos.protocol.pubsub.impl.memory import Memory -from dimos.protocol.pubsub.spec import PubSub - -__all__ = [ - "Memory", - "PubSub", - "lcm", -] diff --git a/dimos/protocol/pubsub/encoders.py b/dimos/protocol/pubsub/encoders.py index 6b2056fa8b..69aa328765 100644 --- a/dimos/protocol/pubsub/encoders.py +++ b/dimos/protocol/pubsub/encoders.py @@ -20,8 +20,8 @@ import pickle from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast -from dimos.msgs import DimosMsg -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.protocol import DimosMsg +from dimos.msgs.sensor_msgs.Image import Image if TYPE_CHECKING: from collections.abc import Callable diff --git a/dimos/protocol/pubsub/impl/__init__.py b/dimos/protocol/pubsub/impl/__init__.py deleted file mode 100644 index 63a5bfa6d6..0000000000 --- a/dimos/protocol/pubsub/impl/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from dimos.protocol.pubsub.impl.lcmpubsub import ( - LCM as LCM, - LCMPubSubBase as LCMPubSubBase, - PickleLCM as PickleLCM, -) -from dimos.protocol.pubsub.impl.memory import Memory as Memory diff --git a/dimos/protocol/pubsub/impl/lcmpubsub.py b/dimos/protocol/pubsub/impl/lcmpubsub.py index 4e792f5965..50c7c49f2f 100644 --- a/dimos/protocol/pubsub/impl/lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/lcmpubsub.py @@ -20,7 +20,7 @@ import threading from typing import Any -from dimos.msgs import DimosMsg +from dimos.msgs.protocol import DimosMsg from dimos.protocol.pubsub.encoders import ( JpegEncoderMixin, LCMEncoderMixin, @@ -63,7 +63,7 @@ def from_channel_str(channel: str, default_lcm_type: type[DimosMsg] | None = Non Channel format: /topic#module.ClassName Falls back to default_lcm_type if type cannot be parsed. """ - from dimos.msgs import resolve_msg_type + from dimos.msgs.helpers import resolve_msg_type if "#" not in channel: return Topic(topic=channel, lcm_type=default_lcm_type) diff --git a/dimos/protocol/pubsub/impl/memory.py b/dimos/protocol/pubsub/impl/memory.py index 3425a5ee3d..25e10efe32 100644 --- a/dimos/protocol/pubsub/impl/memory.py +++ b/dimos/protocol/pubsub/impl/memory.py @@ -16,7 +16,7 @@ from collections.abc import Callable from typing import Any -from dimos.protocol import encode +from dimos.protocol.encode import encoder as encode from dimos.protocol.pubsub.encoders import PubSubEncoderMixin from dimos.protocol.pubsub.spec import PubSub diff --git a/dimos/protocol/pubsub/impl/rospubsub.py b/dimos/protocol/pubsub/impl/rospubsub.py index 1a3c989a4d..1e18b3759a 100644 --- a/dimos/protocol/pubsub/impl/rospubsub.py +++ b/dimos/protocol/pubsub/impl/rospubsub.py @@ -37,7 +37,7 @@ import uuid -from dimos.msgs import DimosMsg +from dimos.msgs.protocol import DimosMsg from dimos.protocol.pubsub.impl.rospubsub_conversion import ( derive_ros_type, dimos_to_ros, diff --git a/dimos/protocol/pubsub/impl/rospubsub_conversion.py b/dimos/protocol/pubsub/impl/rospubsub_conversion.py index 275033a5ac..150c3eeb8f 100644 --- a/dimos/protocol/pubsub/impl/rospubsub_conversion.py +++ b/dimos/protocol/pubsub/impl/rospubsub_conversion.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: - from dimos.msgs import DimosMsg + from dimos.msgs.protocol import DimosMsg from dimos.protocol.pubsub.impl.rospubsub import ROSMessage diff --git a/dimos/protocol/pubsub/impl/test_lcmpubsub.py b/dimos/protocol/pubsub/impl/test_lcmpubsub.py index ea80b4c445..ba29c70958 100644 --- a/dimos/protocol/pubsub/impl/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/test_lcmpubsub.py @@ -18,7 +18,9 @@ import pytest -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.protocol.pubsub.impl.lcmpubsub import ( LCM, LCMPubSubBase, diff --git a/dimos/protocol/pubsub/impl/test_rospubsub.py b/dimos/protocol/pubsub/impl/test_rospubsub.py index 5f574065ba..ef9df74227 100644 --- a/dimos/protocol/pubsub/impl/test_rospubsub.py +++ b/dimos/protocol/pubsub/impl/test_rospubsub.py @@ -28,7 +28,7 @@ # Add msg_name to LCM PointStamped for testing nested message conversion PointStamped.msg_name = "geometry_msgs.PointStamped" from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay def ros_node(): diff --git a/dimos/protocol/pubsub/test_pattern_sub.py b/dimos/protocol/pubsub/test_pattern_sub.py index cdbce5d5a6..ac94ba1b3b 100644 --- a/dimos/protocol/pubsub/test_pattern_sub.py +++ b/dimos/protocol/pubsub/test_pattern_sub.py @@ -24,7 +24,9 @@ import pytest -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.protocol.pubsub.impl.lcmpubsub import LCM, LCMPubSubBase, Topic from dimos.protocol.pubsub.patterns import Glob from dimos.protocol.pubsub.spec import AllPubSub, PubSub diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index a240319fdf..e36741bbfd 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -23,7 +23,7 @@ import pytest -from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic from dimos.protocol.pubsub.impl.memory import Memory diff --git a/dimos/protocol/rpc/__init__.py b/dimos/protocol/rpc/__init__.py deleted file mode 100644 index 1eb892d956..0000000000 --- a/dimos/protocol/rpc/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.protocol.rpc.pubsubrpc import LCMRPC, ShmRPC -from dimos.protocol.rpc.spec import RPCClient, RPCServer, RPCSpec - -__all__ = ["LCMRPC", "RPCClient", "RPCServer", "RPCSpec", "ShmRPC"] diff --git a/dimos/protocol/rpc/test_lcmrpc.py b/dimos/protocol/rpc/test_lcmrpc.py index f31d20cf19..5baa5ac40c 100644 --- a/dimos/protocol/rpc/test_lcmrpc.py +++ b/dimos/protocol/rpc/test_lcmrpc.py @@ -17,7 +17,7 @@ import pytest from dimos.constants import LCM_MAX_CHANNEL_NAME_LENGTH -from dimos.protocol.rpc import LCMRPC +from dimos.protocol.rpc.pubsubrpc import LCMRPC @pytest.fixture diff --git a/dimos/protocol/service/__init__.py b/dimos/protocol/service/__init__.py deleted file mode 100644 index ed6caf93c2..0000000000 --- a/dimos/protocol/service/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from dimos.protocol.service.lcmservice import LCMService -from dimos.protocol.service.spec import BaseConfig, Configurable, Service - -__all__ = ( - "BaseConfig", - "Configurable", - "LCMService", - "Service", -) diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py index 9a563addb1..0211b34129 100644 --- a/dimos/protocol/service/lcmservice.py +++ b/dimos/protocol/service/lcmservice.py @@ -25,7 +25,8 @@ import lcm as lcm_mod from dimos.protocol.service.spec import BaseConfig, Service -from dimos.protocol.service.system_configurator import configure_system, lcm_configurators +from dimos.protocol.service.system_configurator.base import configure_system +from dimos.protocol.service.system_configurator.lcm_config import lcm_configurators from dimos.utils.logging_config import setup_logger if sys.version_info < (3, 13): diff --git a/dimos/protocol/service/system_configurator/__init__.py b/dimos/protocol/service/system_configurator/lcm_config.py similarity index 54% rename from dimos/protocol/service/system_configurator/__init__.py rename to dimos/protocol/service/system_configurator/lcm_config.py index 31b5af4d8c..72f1e5d774 100644 --- a/dimos/protocol/service/system_configurator/__init__.py +++ b/dimos/protocol/service/system_configurator/lcm_config.py @@ -12,18 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""System configurator package — re-exports for backward compatibility.""" +"""Platform-appropriate LCM system configurators.""" import platform -from dimos.protocol.service.system_configurator.base import ( - SystemConfigurator, - configure_system, - sudo_run, -) -from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator +from dimos.protocol.service.system_configurator.base import SystemConfigurator from dimos.protocol.service.system_configurator.lcm import ( - IDEAL_RMEM_SIZE, BufferConfiguratorLinux, BufferConfiguratorMacOS, MaxFileConfiguratorMacOS, @@ -33,17 +27,6 @@ from dimos.protocol.service.system_configurator.libpython import LibPythonConfiguratorMacOS -# TODO: This is a configurator API issue and inserted here temporarily -# -# We need to use different configurators based on the underlying OS -# -# We should have separation of concerns, nothing but configurators themselves care about the OS in this context -# -# So configurators with multi-os behavior should be responsible for the right per-OS behaviour, and -# not external systems -# -# We might want to have some sort of recursive configurators -# def lcm_configurators() -> list[SystemConfigurator]: """Return the platform-appropriate LCM system configurators.""" system = platform.system() @@ -56,23 +39,7 @@ def lcm_configurators() -> list[SystemConfigurator]: return [ MulticastConfiguratorMacOS(loopback_interface="lo0"), BufferConfiguratorMacOS(), - MaxFileConfiguratorMacOS(), # TODO: this is not LCM related and shouldn't be here at all + MaxFileConfiguratorMacOS(), LibPythonConfiguratorMacOS(), ] return [] - - -__all__ = [ - "IDEAL_RMEM_SIZE", - "BufferConfiguratorLinux", - "BufferConfiguratorMacOS", - "ClockSyncConfigurator", - "LibPythonConfiguratorMacOS", - "MaxFileConfiguratorMacOS", - "MulticastConfiguratorLinux", - "MulticastConfiguratorMacOS", - "SystemConfigurator", - "configure_system", - "lcm_configurators", - "sudo_run", -] diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py index 78085e2363..cbab6ff3ab 100644 --- a/dimos/protocol/service/test_lcmservice.py +++ b/dimos/protocol/service/test_lcmservice.py @@ -25,14 +25,14 @@ LCMService, autoconf, ) -from dimos.protocol.service.system_configurator import ( +from dimos.protocol.service.system_configurator.lcm import ( BufferConfiguratorLinux, BufferConfiguratorMacOS, - LibPythonConfiguratorMacOS, MaxFileConfiguratorMacOS, MulticastConfiguratorLinux, MulticastConfiguratorMacOS, ) +from dimos.protocol.service.system_configurator.libpython import LibPythonConfiguratorMacOS # autoconf tests @@ -40,7 +40,8 @@ class TestConfigureSystemForLcm: def test_creates_linux_checks_on_linux(self) -> None: with patch( - "dimos.protocol.service.system_configurator.platform.system", return_value="Linux" + "dimos.protocol.service.system_configurator.lcm_config.platform.system", + return_value="Linux", ): with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure: autoconf() @@ -53,7 +54,8 @@ def test_creates_linux_checks_on_linux(self) -> None: def test_creates_macos_checks_on_darwin(self) -> None: with patch( - "dimos.protocol.service.system_configurator.platform.system", return_value="Darwin" + "dimos.protocol.service.system_configurator.lcm_config.platform.system", + return_value="Darwin", ): with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure: autoconf() @@ -68,7 +70,8 @@ def test_creates_macos_checks_on_darwin(self) -> None: def test_passes_check_only_flag(self) -> None: with patch( - "dimos.protocol.service.system_configurator.platform.system", return_value="Linux" + "dimos.protocol.service.system_configurator.lcm_config.platform.system", + return_value="Linux", ): with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure: autoconf(check_only=True) @@ -77,7 +80,8 @@ def test_passes_check_only_flag(self) -> None: def test_logs_error_on_unsupported_system(self) -> None: with patch( - "dimos.protocol.service.system_configurator.platform.system", return_value="Windows" + "dimos.protocol.service.system_configurator.lcm_config.platform.system", + return_value="Windows", ): with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure: with patch("dimos.protocol.service.lcmservice.logger") as mock_logger: diff --git a/dimos/protocol/service/test_system_configurator.py b/dimos/protocol/service/test_system_configurator.py index 1bd44aa5e2..715d9eede7 100644 --- a/dimos/protocol/service/test_system_configurator.py +++ b/dimos/protocol/service/test_system_configurator.py @@ -19,22 +19,22 @@ import pytest -from dimos.protocol.service.system_configurator import ( +from dimos.protocol.service.system_configurator.base import ( + SystemConfigurator, + _is_root_user, + _read_sysctl_int, + _write_sysctl_int, + configure_system, + sudo_run, +) +from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator +from dimos.protocol.service.system_configurator.lcm import ( IDEAL_RMEM_SIZE, BufferConfiguratorLinux, BufferConfiguratorMacOS, - ClockSyncConfigurator, MaxFileConfiguratorMacOS, MulticastConfiguratorLinux, MulticastConfiguratorMacOS, - SystemConfigurator, - configure_system, - sudo_run, -) -from dimos.protocol.service.system_configurator.base import ( - _is_root_user, - _read_sysctl_int, - _write_sysctl_int, ) # Helper function tests diff --git a/dimos/protocol/tf/__init__.py b/dimos/protocol/tf/__init__.py deleted file mode 100644 index cb00dbde3c..0000000000 --- a/dimos/protocol/tf/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.protocol.tf.tf import LCMTF, TF, MultiTBuffer, PubSubTF, TBuffer, TFConfig, TFSpec - -__all__ = ["LCMTF", "TF", "MultiTBuffer", "PubSubTF", "TBuffer", "TFConfig", "TFSpec"] diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index c1f0b13fa2..b0843bfccd 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -19,8 +19,11 @@ import pytest -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.protocol.tf import TF, MultiTBuffer, TBuffer +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.protocol.tf.tf import TF, MultiTBuffer, TBuffer # from https://foxglove.dev/blog/understanding-ros-transforms diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 1b5ccadf3c..97b2132bbb 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -21,8 +21,9 @@ from typing import TypeVar from dimos.memory.timeseries.inmemory import InMemoryStore -from dimos.msgs.geometry_msgs import PoseStamped, Transform -from dimos.msgs.tf2_msgs import TFMessage +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic from dimos.protocol.pubsub.spec import PubSub from dimos.protocol.service.spec import BaseConfig, Service diff --git a/dimos/protocol/tf/tflcmcpp.py b/dimos/protocol/tf/tflcmcpp.py index bf2885958d..aec1f947ce 100644 --- a/dimos/protocol/tf/tflcmcpp.py +++ b/dimos/protocol/tf/tflcmcpp.py @@ -14,7 +14,7 @@ from datetime import datetime -from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.geometry_msgs.Transform import Transform from dimos.protocol.service.lcmservice import LCMConfig, LCMService from dimos.protocol.tf.tf import TFConfig, TFSpec diff --git a/dimos/robot/__init__.py b/dimos/robot/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/drone/__init__.py b/dimos/robot/drone/__init__.py deleted file mode 100644 index 828059e99d..0000000000 --- a/dimos/robot/drone/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Generic drone module for MAVLink-based drones.""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "camera_module": ["DroneCameraModule"], - "connection_module": ["DroneConnectionModule"], - "mavlink_connection": ["MavlinkConnection"], - }, -) diff --git a/dimos/robot/drone/blueprints/__init__.py b/dimos/robot/drone/blueprints/__init__.py deleted file mode 100644 index d011c6e4fb..0000000000 --- a/dimos/robot/drone/blueprints/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DimOS Drone blueprints.""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "basic.drone_basic": ["drone_basic"], - "agentic.drone_agentic": ["drone_agentic"], - }, -) diff --git a/dimos/robot/drone/blueprints/agentic/__init__.py b/dimos/robot/drone/blueprints/agentic/__init__.py deleted file mode 100644 index a7386b8f45..0000000000 --- a/dimos/robot/drone/blueprints/agentic/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Agentic drone blueprint.""" - -from dimos.robot.drone.blueprints.agentic.drone_agentic import drone_agentic - -__all__ = ["drone_agentic"] diff --git a/dimos/robot/drone/blueprints/basic/__init__.py b/dimos/robot/drone/blueprints/basic/__init__.py deleted file mode 100644 index 3bf4ec60ff..0000000000 --- a/dimos/robot/drone/blueprints/basic/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Basic drone blueprint.""" - -from dimos.robot.drone.blueprints.basic.drone_basic import drone_basic - -__all__ = ["drone_basic"] diff --git a/dimos/robot/drone/camera_module.py b/dimos/robot/drone/camera_module.py index 63389aa358..5343549c66 100644 --- a/dimos/robot/drone/camera_module.py +++ b/dimos/robot/drone/camera_module.py @@ -26,9 +26,9 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.std_msgs import Header +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.std_msgs.Header import Header from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/robot/drone/connection_module.py b/dimos/robot/drone/connection_module.py index c606e7467e..863f719bad 100644 --- a/dimos/robot/drone/connection_module.py +++ b/dimos/robot/drone/connection_module.py @@ -28,9 +28,13 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.mapping.types import LatLon -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 -from dimos.msgs.sensor_msgs import Image +from dimos.mapping.models import LatLon +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image from dimos.robot.drone.dji_video_stream import DJIDroneVideoStream from dimos.robot.drone.mavlink_connection import MavlinkConnection from dimos.utils.logging_config import setup_logger diff --git a/dimos/robot/drone/dji_video_stream.py b/dimos/robot/drone/dji_video_stream.py index 1810fd4212..60618ae712 100644 --- a/dimos/robot/drone/dji_video_stream.py +++ b/dimos/robot/drone/dji_video_stream.py @@ -26,7 +26,7 @@ import numpy as np from reactivex import Observable, Subject -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -214,7 +214,7 @@ def get_stream(self) -> Observable[Image]: # type: ignore[override] """ from reactivex import operators as ops - from dimos.utils.testing import TimedSensorReplay + from dimos.utils.testing.replay import TimedSensorReplay def _fix_format(img: Image) -> Image: if img.format == ImageFormat.BGR: diff --git a/dimos/robot/drone/drone_tracking_module.py b/dimos/robot/drone/drone_tracking_module.py index 276b636633..5798db374b 100644 --- a/dimos/robot/drone/drone_tracking_module.py +++ b/dimos/robot/drone/drone_tracking_module.py @@ -29,8 +29,9 @@ from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.models.qwen.video_query import get_bbox_from_qwen_frame -from dimos.msgs.geometry_msgs import Twist, Vector3 -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.robot.drone.drone_visual_servoing_controller import ( DroneVisualServoingController, PIDParams, diff --git a/dimos/robot/drone/mavlink_connection.py b/dimos/robot/drone/mavlink_connection.py index d8a7c97c4a..076d9cd369 100644 --- a/dimos/robot/drone/mavlink_connection.py +++ b/dimos/robot/drone/mavlink_connection.py @@ -23,7 +23,10 @@ from pymavlink import mavutil # type: ignore[import-not-found, import-untyped] from reactivex import Subject -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Twist, Vector3 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.utils.logging_config import setup_logger logger = setup_logger(level=logging.INFO) @@ -1028,7 +1031,7 @@ def __init__(self, connection_string: str) -> None: class FakeMavlink: def __init__(self) -> None: from dimos.utils.data import get_data - from dimos.utils.testing import TimedSensorReplay + from dimos.utils.testing.replay import TimedSensorReplay get_data("drone") diff --git a/dimos/robot/drone/test_drone.py b/dimos/robot/drone/test_drone.py index 88c45c9aa8..0b30c22c35 100644 --- a/dimos/robot/drone/test_drone.py +++ b/dimos/robot/drone/test_drone.py @@ -25,8 +25,10 @@ import numpy as np -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.robot.drone.connection_module import DroneConnectionModule from dimos.robot.drone.dji_video_stream import FakeDJIVideoStream @@ -192,7 +194,7 @@ class TestReplayMode(unittest.TestCase): def test_fake_mavlink_connection(self) -> None: """Test FakeMavlinkConnection replays messages correctly.""" - with patch("dimos.utils.testing.TimedSensorReplay") as mock_replay: + with patch("dimos.utils.testing.replay.TimedSensorReplay") as mock_replay: # Mock the replay stream MagicMock() mock_messages = [ @@ -218,7 +220,7 @@ def test_fake_mavlink_connection(self) -> None: def test_fake_video_stream_no_throttling(self) -> None: """Test FakeDJIVideoStream returns replay stream with format fix.""" - with patch("dimos.utils.testing.TimedSensorReplay") as mock_replay: + with patch("dimos.utils.testing.replay.TimedSensorReplay") as mock_replay: mock_stream = MagicMock() mock_replay.return_value.stream.return_value = mock_stream @@ -280,7 +282,7 @@ def test_connection_module_replay_with_messages(self) -> None: os.environ["DRONE_CONNECTION"] = "replay" - with patch("dimos.utils.testing.TimedSensorReplay") as mock_replay: + with patch("dimos.utils.testing.replay.TimedSensorReplay") as mock_replay: # Set up MAVLink replay stream mavlink_messages = [ {"mavpackettype": "HEARTBEAT", "type": 2, "base_mode": 193}, @@ -433,7 +435,7 @@ def tearDown(self) -> None: self.foxglove_patch.stop() @patch("dimos.robot.drone.drone.ModuleCoordinator") - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") def test_full_system_with_replay(self, mock_replay, mock_coordinator_class) -> None: """Test full drone system initialization and operation with replay mode.""" # Set up mock replay data @@ -567,7 +569,7 @@ def deploy_side_effect(module_class, **kwargs): class TestDroneControlCommands(unittest.TestCase): """Test drone control commands with FakeMavlinkConnection.""" - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_arm_disarm_commands(self, mock_get_data, mock_replay) -> None: """Test arm and disarm commands work with fake connection.""" @@ -586,7 +588,7 @@ def test_arm_disarm_commands(self, mock_get_data, mock_replay) -> None: result = conn.disarm() self.assertIsInstance(result, bool) # Should return bool without crashing - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_takeoff_land_commands(self, mock_get_data, mock_replay) -> None: """Test takeoff and land commands with fake connection.""" @@ -605,7 +607,7 @@ def test_takeoff_land_commands(self, mock_get_data, mock_replay) -> None: result = conn.land() self.assertIsNotNone(result) - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_set_mode_command(self, mock_get_data, mock_replay) -> None: """Test flight mode setting with fake connection.""" @@ -626,7 +628,7 @@ def test_set_mode_command(self, mock_get_data, mock_replay) -> None: class TestDronePerception(unittest.TestCase): """Test drone perception capabilities.""" - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_video_stream_replay(self, mock_get_data, mock_replay) -> None: """Test video stream works with replay data.""" @@ -696,7 +698,7 @@ def piped_subscribe(callback): # type: ignore[no-untyped-def] class TestDroneMovementAndOdometry(unittest.TestCase): """Test drone movement commands and odometry.""" - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_movement_command_conversion(self, mock_get_data, mock_replay) -> None: """Test movement commands are properly converted from ROS to NED.""" @@ -716,7 +718,7 @@ def test_movement_command_conversion(self, mock_get_data, mock_replay) -> None: # Movement should be converted to NED internally # The fake connection doesn't actually send commands, but it should not crash - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_odometry_from_replay(self, mock_get_data, mock_replay) -> None: """Test odometry is properly generated from replay messages.""" @@ -763,7 +765,7 @@ def replay_stream_subscribe(callback) -> None: self.assertIsNotNone(odom.orientation) self.assertEqual(odom.frame_id, "world") - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_position_integration_indoor(self, mock_get_data, mock_replay) -> None: """Test position integration for indoor flight without GPS.""" @@ -808,7 +810,7 @@ def replay_stream_subscribe(callback) -> None: class TestDroneStatusAndTelemetry(unittest.TestCase): """Test drone status and telemetry reporting.""" - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_status_extraction(self, mock_get_data, mock_replay) -> None: """Test status is properly extracted from MAVLink messages.""" @@ -853,7 +855,7 @@ def replay_stream_subscribe(callback) -> None: self.assertIn("altitude", status) self.assertIn("heading", status) - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_telemetry_json_publishing(self, mock_get_data, mock_replay) -> None: """Test full telemetry is published as JSON.""" @@ -907,7 +909,7 @@ def replay_stream_subscribe(callback) -> None: class TestFlyToErrorHandling(unittest.TestCase): """Test fly_to() error handling paths.""" - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_concurrency_lock(self, mock_get_data, mock_replay) -> None: """flying_to_target=True rejects concurrent fly_to() calls.""" @@ -921,7 +923,7 @@ def test_concurrency_lock(self, mock_get_data, mock_replay) -> None: result = conn.fly_to(37.0, -122.0, 10.0) self.assertIn("Already flying to target", result) - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_error_when_not_connected(self, mock_get_data, mock_replay) -> None: """connected=False returns error immediately.""" diff --git a/dimos/robot/manipulators/__init__.py b/dimos/robot/manipulators/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/manipulators/piper/__init__.py b/dimos/robot/manipulators/piper/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/manipulators/piper/blueprints.py b/dimos/robot/manipulators/piper/blueprints.py index 68e02fc994..ead27fd54b 100644 --- a/dimos/robot/manipulators/piper/blueprints.py +++ b/dimos/robot/manipulators/piper/blueprints.py @@ -27,9 +27,11 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport from dimos.manipulation.manipulation_module import manipulation_module -from dimos.manipulation.planning.spec import RobotModelConfig -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.msgs.sensor_msgs import JointState +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.keyboard.keyboard_teleop_module import keyboard_teleop_module from dimos.utils.data import LfsPath, get_data diff --git a/dimos/robot/manipulators/xarm/__init__.py b/dimos/robot/manipulators/xarm/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/manipulators/xarm/blueprints.py b/dimos/robot/manipulators/xarm/blueprints.py index 9a1732217b..e699057b44 100644 --- a/dimos/robot/manipulators/xarm/blueprints.py +++ b/dimos/robot/manipulators/xarm/blueprints.py @@ -32,8 +32,8 @@ _make_xarm7_config, ) from dimos.manipulation.manipulation_module import manipulation_module -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.keyboard.keyboard_teleop_module import keyboard_teleop_module from dimos.utils.data import LfsPath diff --git a/dimos/robot/unitree/__init__.py b/dimos/robot/unitree/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/unitree/b1/__init__.py b/dimos/robot/unitree/b1/__init__.py deleted file mode 100644 index db85984070..0000000000 --- a/dimos/robot/unitree/b1/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. - -"""Unitree B1 robot module.""" - -from .unitree_b1 import UnitreeB1 - -__all__ = ["UnitreeB1"] diff --git a/dimos/robot/unitree/b1/connection.py b/dimos/robot/unitree/b1/connection.py index 445044020d..11af31b296 100644 --- a/dimos/robot/unitree/b1/connection.py +++ b/dimos/robot/unitree/b1/connection.py @@ -28,9 +28,11 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry -from dimos.msgs.std_msgs import Int32 +from dimos.msgs.std_msgs.Int32 import Int32 from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.utils.logging_config import setup_logger diff --git a/dimos/robot/unitree/b1/joystick_module.py b/dimos/robot/unitree/b1/joystick_module.py index 9fbfd84f1e..234ff129c9 100644 --- a/dimos/robot/unitree/b1/joystick_module.py +++ b/dimos/robot/unitree/b1/joystick_module.py @@ -28,8 +28,10 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import Twist, TwistStamped, Vector3 -from dimos.msgs.std_msgs import Int32 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.std_msgs.Int32 import Int32 class JoystickModule(Module): diff --git a/dimos/robot/unitree/b1/test_connection.py b/dimos/robot/unitree/b1/test_connection.py index e43a3124dc..f1ff5ad861 100644 --- a/dimos/robot/unitree/b1/test_connection.py +++ b/dimos/robot/unitree/b1/test_connection.py @@ -25,7 +25,8 @@ import threading import time -from dimos.msgs.geometry_msgs import TwistStamped, Vector3 +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.std_msgs.Int32 import Int32 from .connection import MockB1ConnectionModule diff --git a/dimos/robot/unitree/b1/unitree_b1.py b/dimos/robot/unitree/b1/unitree_b1.py index 6b374d1d5b..9a6d04a7ff 100644 --- a/dimos/robot/unitree/b1/unitree_b1.py +++ b/dimos/robot/unitree/b1/unitree_b1.py @@ -26,9 +26,10 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.resource import Resource from dimos.core.transport import LCMTransport, ROSTransport -from dimos.msgs.geometry_msgs import PoseStamped, TwistStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry -from dimos.msgs.std_msgs import Int32 +from dimos.msgs.std_msgs.Int32 import Int32 from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.robot.robot import Robot from dimos.robot.unitree.b1.connection import ( diff --git a/dimos/robot/unitree/connection.py b/dimos/robot/unitree/connection.py index ff73d922ee..7e60080f01 100644 --- a/dimos/robot/unitree/connection.py +++ b/dimos/robot/unitree/connection.py @@ -35,9 +35,11 @@ ) from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import Pose, Transform, Twist -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.sensor_msgs.Image import ImageFormat +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import RawLidarMsg, pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.lowstate import LowStateMsg from dimos.robot.unitree.type.odometry import Odometry diff --git a/dimos/robot/unitree/g1/blueprints/__init__.py b/dimos/robot/unitree/g1/blueprints/__init__.py deleted file mode 100644 index ebc18da8d3..0000000000 --- a/dimos/robot/unitree/g1/blueprints/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Cascaded G1 blueprints split into focused modules.""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "agentic._agentic_skills": ["_agentic_skills"], - "agentic.unitree_g1_agentic": ["unitree_g1_agentic"], - "agentic.unitree_g1_agentic_sim": ["unitree_g1_agentic_sim"], - "agentic.unitree_g1_full": ["unitree_g1_full"], - "basic.unitree_g1_basic": ["unitree_g1_basic"], - "basic.unitree_g1_basic_sim": ["unitree_g1_basic_sim"], - "basic.unitree_g1_joystick": ["unitree_g1_joystick"], - "perceptive._perception_and_memory": ["_perception_and_memory"], - "perceptive.unitree_g1": ["unitree_g1"], - "perceptive.unitree_g1_detection": ["unitree_g1_detection"], - "perceptive.unitree_g1_shm": ["unitree_g1_shm"], - "perceptive.unitree_g1_sim": ["unitree_g1_sim"], - "primitive.uintree_g1_primitive_no_nav": ["uintree_g1_primitive_no_nav", "basic_no_nav"], - }, -) diff --git a/dimos/robot/unitree/g1/blueprints/agentic/__init__.py b/dimos/robot/unitree/g1/blueprints/agentic/__init__.py deleted file mode 100644 index 5e6db90d91..0000000000 --- a/dimos/robot/unitree/g1/blueprints/agentic/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Agentic blueprints for Unitree G1.""" diff --git a/dimos/robot/unitree/g1/blueprints/basic/__init__.py b/dimos/robot/unitree/g1/blueprints/basic/__init__.py deleted file mode 100644 index 87e6586f56..0000000000 --- a/dimos/robot/unitree/g1/blueprints/basic/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Basic blueprints for Unitree G1.""" diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/__init__.py b/dimos/robot/unitree/g1/blueprints/perceptive/__init__.py deleted file mode 100644 index 9bd838e8b8..0000000000 --- a/dimos/robot/unitree/g1/blueprints/perceptive/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Perceptive blueprints for Unitree G1.""" diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py index 25bff97c73..18884bd7af 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py @@ -22,10 +22,11 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport -from dimos.hardware.sensors.camera import zed -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.hardware.sensors.camera.zed import compat as zed +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector from dimos.perception.detection.module3D import Detection3DModule, detection3d_module from dimos.perception.detection.moduleDB import ObjectDBModule, detection_db_module diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py index 5ee4d4c9d1..be67194b62 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py @@ -18,7 +18,7 @@ from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.blueprints import autoconnect from dimos.core.transport import pSHMTransport -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.robot.foxglove_bridge import foxglove_bridge from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1 import unitree_g1 diff --git a/dimos/robot/unitree/g1/blueprints/primitive/__init__.py b/dimos/robot/unitree/g1/blueprints/primitive/__init__.py deleted file mode 100644 index 833f767728..0000000000 --- a/dimos/robot/unitree/g1/blueprints/primitive/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Primitive blueprints for Unitree G1.""" diff --git a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py index c47fdc377b..242fcaf38f 100644 --- a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py +++ b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py @@ -22,16 +22,24 @@ from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config from dimos.core.transport import LCMTransport -from dimos.hardware.sensors.camera import zed from dimos.hardware.sensors.camera.module import camera_module # type: ignore[attr-defined] from dimos.hardware.sensors.camera.webcam import Webcam +from dimos.hardware.sensors.camera.zed import compat as zed from dimos.mapping.costmapper import cost_mapper from dimos.mapping.voxels import voxel_mapper -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 -from dimos.msgs.nav_msgs import Odometry, Path -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.std_msgs import Bool -from dimos.navigation.frontier_exploration import wavefront_frontier_explorer +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.nav_msgs.Path import Path +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.std_msgs.Bool import Bool +from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( + wavefront_frontier_explorer, +) from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.web.websocket_vis.websocket_vis_module import websocket_vis diff --git a/dimos/robot/unitree/g1/connection.py b/dimos/robot/unitree/g1/connection.py index 94f725ac7e..1f3788de98 100644 --- a/dimos/robot/unitree/g1/connection.py +++ b/dimos/robot/unitree/g1/connection.py @@ -19,13 +19,13 @@ from pydantic import Field from reactivex.disposable import Disposable -from dimos import spec from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In -from dimos.msgs.geometry_msgs import Twist +from dimos.msgs.geometry_msgs.Twist import Twist from dimos.robot.unitree.connection import UnitreeWebRTCConnection +from dimos.spec.control import LocalPlanner from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -115,7 +115,7 @@ def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: g1_connection = G1Connection.blueprint -def deploy(dimos: ModuleCoordinator, ip: str, local_planner: spec.LocalPlanner) -> "ModuleProxy": +def deploy(dimos: ModuleCoordinator, ip: str, local_planner: LocalPlanner) -> "ModuleProxy": connection = dimos.deploy(G1Connection, ip=ip) connection.cmd_vel.connect(local_planner.cmd_vel) connection.start() diff --git a/dimos/robot/unitree/g1/sim.py b/dimos/robot/unitree/g1/sim.py index 9226bb4e7f..206a689284 100644 --- a/dimos/robot/unitree/g1/sim.py +++ b/dimos/robot/unitree/g1/sim.py @@ -24,14 +24,14 @@ from dimos.core.core import rpc from dimos.core.module import ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import ( - PoseStamped, - Quaternion, - Transform, - Twist, - Vector3, -) -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.g1.connection import G1ConnectionBase from dimos.robot.unitree.mujoco_connection import MujocoConnection from dimos.robot.unitree.type.odometry import Odometry as SimOdometry diff --git a/dimos/robot/unitree/g1/skill_container.py b/dimos/robot/unitree/g1/skill_container.py index 2bd5bcdb49..b1342ca96d 100644 --- a/dimos/robot/unitree/g1/skill_container.py +++ b/dimos/robot/unitree/g1/skill_container.py @@ -22,7 +22,8 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module -from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/robot/unitree/go2/blueprints/__init__.py b/dimos/robot/unitree/go2/blueprints/__init__.py deleted file mode 100644 index cbc49694f3..0000000000 --- a/dimos/robot/unitree/go2/blueprints/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Cascaded GO2 blueprints split into focused modules.""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "agentic._common_agentic": ["_common_agentic"], - "agentic.unitree_go2_agentic": ["unitree_go2_agentic"], - "agentic.unitree_go2_agentic_huggingface": ["unitree_go2_agentic_huggingface"], - "agentic.unitree_go2_agentic_mcp": ["unitree_go2_agentic_mcp"], - "agentic.unitree_go2_agentic_ollama": ["unitree_go2_agentic_ollama"], - "agentic.unitree_go2_temporal_memory": ["unitree_go2_temporal_memory"], - "basic.unitree_go2_basic": ["_linux", "_mac", "unitree_go2_basic"], - "smart._with_jpeg": ["_with_jpeglcm"], - "smart.unitree_go2": ["unitree_go2"], - "smart.unitree_go2_detection": ["unitree_go2_detection"], - "smart.unitree_go2_ros": ["unitree_go2_ros"], - "smart.unitree_go2_spatial": ["unitree_go2_spatial"], - "smart.unitree_go2_vlm_stream_test": ["unitree_go2_vlm_stream_test"], - }, -) diff --git a/dimos/robot/unitree/go2/blueprints/agentic/__init__.py b/dimos/robot/unitree/go2/blueprints/agentic/__init__.py deleted file mode 100644 index 84d1b41b23..0000000000 --- a/dimos/robot/unitree/go2/blueprints/agentic/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Agentic blueprints for Unitree GO2.""" diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py index 13a1eec1ff..24ab47ad3b 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py @@ -15,7 +15,10 @@ from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config -from dimos.perception.experimental.temporal_memory import TemporalMemoryConfig, temporal_memory +from dimos.perception.experimental.temporal_memory.temporal_memory import ( + TemporalMemoryConfig, + temporal_memory, +) from dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic import unitree_go2_agentic # This module is imported lazily by `get_by_name()` in the CLI run command, diff --git a/dimos/robot/unitree/go2/blueprints/basic/__init__.py b/dimos/robot/unitree/go2/blueprints/basic/__init__.py deleted file mode 100644 index 79964b0297..0000000000 --- a/dimos/robot/unitree/go2/blueprints/basic/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Basic blueprints for Unitree GO2.""" diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index ce8aef2222..3325290bf7 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -21,9 +21,9 @@ from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from dimos.protocol.service.system_configurator import ClockSyncConfigurator +from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.connection import go2_connection from dimos.web.websocket_vis.websocket_vis_module import websocket_vis diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py index 015cfcdba4..908444b2fd 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py @@ -21,7 +21,7 @@ """ from dimos.core.blueprints import autoconnect -from dimos.protocol.service.system_configurator import ClockSyncConfigurator +from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import with_vis from dimos.robot.unitree.go2.fleet_connection import go2_fleet_connection from dimos.web.websocket_vis.websocket_vis_module import websocket_vis diff --git a/dimos/robot/unitree/go2/blueprints/smart/__init__.py b/dimos/robot/unitree/go2/blueprints/smart/__init__.py deleted file mode 100644 index 7d5bdbc3ab..0000000000 --- a/dimos/robot/unitree/go2/blueprints/smart/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Smart blueprints for Unitree GO2.""" diff --git a/dimos/robot/unitree/go2/blueprints/smart/_with_jpeg.py b/dimos/robot/unitree/go2/blueprints/smart/_with_jpeg.py index 9c77d599cf..a759b1ca50 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/_with_jpeg.py +++ b/dimos/robot/unitree/go2/blueprints/smart/_with_jpeg.py @@ -14,7 +14,7 @@ # limitations under the License. from dimos.core.transport import JpegLcmTransport -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 _with_jpeglcm = unitree_go2.transports( diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py index 22743ac135..80e6ec701a 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py @@ -16,7 +16,9 @@ from dimos.core.blueprints import autoconnect from dimos.mapping.costmapper import cost_mapper from dimos.mapping.voxels import voxel_mapper -from dimos.navigation.frontier_exploration import wavefront_frontier_explorer +from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( + wavefront_frontier_explorer, +) from dimos.navigation.replanning_a_star.module import replanning_a_star_planner from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py index f2edf2cb3b..a9bb7729ae 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py @@ -20,8 +20,9 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.module3D import Detection3DModule, detection3d_module from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 from dimos.robot.unitree.go2.connection import GO2Connection diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_ros.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_ros.py index a335b1e9af..b63b8f5f6c 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_ros.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_ros.py @@ -14,8 +14,9 @@ # limitations under the License. from dimos.core.transport import ROSTransport -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 unitree_go2_ros = unitree_go2.transports( diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index c06028ec6f..38da7fb439 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -23,7 +23,6 @@ from reactivex.observable import Observable import rerun.blueprint as rrb -from dimos import spec from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.global_config import GlobalConfig @@ -31,18 +30,18 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, pSHMTransport +from dimos.spec.perception import Camera, Pointcloud if TYPE_CHECKING: from dimos.core.rpc_client import ModuleProxy -from dimos.msgs.geometry_msgs import ( - PoseStamped, - Quaternion, - Transform, - Twist, - Vector3, -) -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 -from dimos.msgs.sensor_msgs.Image import ImageFormat +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.connection import UnitreeWebRTCConnection from dimos.utils.data import get_data from dimos.utils.decorators.decorators import simple_mcache @@ -184,7 +183,7 @@ def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-de _Config = TypeVar("_Config", bound=ConnectionConfig, default=ConnectionConfig) -class GO2Connection(Module[_Config], spec.Camera, spec.Pointcloud): +class GO2Connection(Module[_Config], Camera, Pointcloud): default_config = ConnectionConfig # type: ignore[assignment] cmd_vel: In[Twist] diff --git a/dimos/robot/unitree/go2/fleet_connection.py b/dimos/robot/unitree/go2/fleet_connection.py index 24a95ec4d2..f0e904648a 100644 --- a/dimos/robot/unitree/go2/fleet_connection.py +++ b/dimos/robot/unitree/go2/fleet_connection.py @@ -37,7 +37,7 @@ from typing import Any as Self if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import Twist + from dimos.msgs.geometry_msgs.Twist import Twist logger = setup_logger() diff --git a/dimos/robot/unitree/keyboard_teleop.py b/dimos/robot/unitree/keyboard_teleop.py index 3cd03df785..86885bc446 100644 --- a/dimos/robot/unitree/keyboard_teleop.py +++ b/dimos/robot/unitree/keyboard_teleop.py @@ -22,7 +22,8 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 # Force X11 driver to avoid OpenGL threading issues os.environ["SDL_VIDEODRIVER"] = "x11" diff --git a/dimos/robot/unitree/modular/detect.py b/dimos/robot/unitree/modular/detect.py index 99faddc946..d6ed78d101 100644 --- a/dimos/robot/unitree/modular/detect.py +++ b/dimos/robot/unitree/modular/detect.py @@ -16,8 +16,9 @@ from dimos_lcm.sensor_msgs import CameraInfo -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.std_msgs import Header +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.std_msgs.Header import Header from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry @@ -71,8 +72,10 @@ def camera_info() -> CameraInfo: def transform_chain(odom_frame: Odometry) -> list: # type: ignore[type-arg] - from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 - from dimos.protocol.tf import TF + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + from dimos.msgs.geometry_msgs.Transform import Transform + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + from dimos.protocol.tf.tf import TF camera_link = Transform( translation=Vector3(0.3, 0.0, 0.0), @@ -113,7 +116,7 @@ def broadcast( # type: ignore[no-untyped-def] ) from dimos.core.transport import LCMTransport - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped lidar_transport = LCMTransport("/lidar", PointCloud2) # type: ignore[var-annotated] odom_transport = LCMTransport("/odom", PoseStamped) # type: ignore[var-annotated] @@ -136,14 +139,14 @@ def broadcast( # type: ignore[no-untyped-def] def process_data(): # type: ignore[no-untyped-def] - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.module2D import ( # type: ignore[attr-defined] Detection2DModule, build_imageannotations, ) from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data - from dimos.utils.testing import TimedSensorReplay + from dimos.utils.testing.replay import TimedSensorReplay get_data("unitree_office_walk") target = 1751591272.9654856 diff --git a/dimos/robot/unitree/mujoco_connection.py b/dimos/robot/unitree/mujoco_connection.py index 3bc4e075f7..d7c98cffd3 100644 --- a/dimos/robot/unitree/mujoco_connection.py +++ b/dimos/robot/unitree/mujoco_connection.py @@ -35,8 +35,12 @@ from reactivex.disposable import Disposable from dimos.core.global_config import GlobalConfig -from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo, Image, ImageFormat, PointCloud2 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.odometry import Odometry from dimos.simulation.mujoco.constants import ( LAUNCHER_PATH, diff --git a/dimos/robot/unitree/rosnav.py b/dimos/robot/unitree/rosnav.py index 083c7413fe..b2fe42fde5 100644 --- a/dimos/robot/unitree/rosnav.py +++ b/dimos/robot/unitree/rosnav.py @@ -19,8 +19,8 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Joy +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Joy import Joy from dimos.msgs.std_msgs.Bool import Bool from dimos.utils.logging_config import setup_logger diff --git a/dimos/robot/unitree/testing/__init__.py b/dimos/robot/unitree/testing/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/unitree/testing/mock.py b/dimos/robot/unitree/testing/mock.py index 26e6a90018..4c5e52e4b0 100644 --- a/dimos/robot/unitree/testing/mock.py +++ b/dimos/robot/unitree/testing/mock.py @@ -21,7 +21,7 @@ from reactivex import from_iterable, interval, operators as ops from reactivex.observable import Observable -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import RawLidarMsg, pointcloud2_from_webrtc_lidar diff --git a/dimos/robot/unitree/testing/test_actors.py b/dimos/robot/unitree/testing/test_actors.py index ed0b05d664..77c3d7c56f 100644 --- a/dimos/robot/unitree/testing/test_actors.py +++ b/dimos/robot/unitree/testing/test_actors.py @@ -20,7 +20,7 @@ from dimos.core.module import Module from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.map import Map as Mapper diff --git a/dimos/robot/unitree/testing/test_tooling.py b/dimos/robot/unitree/testing/test_tooling.py index d1f2eeb169..40db01feee 100644 --- a/dimos/robot/unitree/testing/test_tooling.py +++ b/dimos/robot/unitree/testing/test_tooling.py @@ -19,7 +19,7 @@ from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.reactive import backpressure -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay @pytest.mark.tool diff --git a/dimos/robot/unitree/type/__init__.py b/dimos/robot/unitree/type/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/unitree/type/lidar.py b/dimos/robot/unitree/type/lidar.py index df2909dc38..f58268d442 100644 --- a/dimos/robot/unitree/type/lidar.py +++ b/dimos/robot/unitree/type/lidar.py @@ -20,7 +20,7 @@ import numpy as np import open3d as o3d # type: ignore[import-untyped] -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 # Backwards compatibility alias for pickled data LidarMessage = PointCloud2 diff --git a/dimos/robot/unitree/type/map.py b/dimos/robot/unitree/type/map.py index 274115d516..da45c003f7 100644 --- a/dimos/robot/unitree/type/map.py +++ b/dimos/robot/unitree/type/map.py @@ -28,8 +28,8 @@ from dimos.mapping.pointclouds.accumulators.general import GeneralPointCloudAccumulator from dimos.mapping.pointclouds.accumulators.protocol import PointCloudAccumulator from dimos.mapping.pointclouds.occupancy import general_occupancy -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.go2.connection import Go2ConnectionProtocol diff --git a/dimos/robot/unitree/type/odometry.py b/dimos/robot/unitree/type/odometry.py index aa664b32ef..fabf800b6c 100644 --- a/dimos/robot/unitree/type/odometry.py +++ b/dimos/robot/unitree/type/odometry.py @@ -13,7 +13,9 @@ # limitations under the License. from typing import Literal, TypedDict -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.robot.unitree.type.timeseries import ( Timestamped, ) diff --git a/dimos/robot/unitree/type/test_lidar.py b/dimos/robot/unitree/type/test_lidar.py index 719088d77a..9a743d65b5 100644 --- a/dimos/robot/unitree/type/test_lidar.py +++ b/dimos/robot/unitree/type/test_lidar.py @@ -16,9 +16,9 @@ import itertools from typing import cast -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import RawLidarMsg, pointcloud2_from_webrtc_lidar -from dimos.utils.testing import SensorReplay +from dimos.utils.testing.replay import SensorReplay def test_init() -> None: diff --git a/dimos/robot/unitree/type/test_odometry.py b/dimos/robot/unitree/type/test_odometry.py index d0fe2b290e..8020684fb7 100644 --- a/dimos/robot/unitree/type/test_odometry.py +++ b/dimos/robot/unitree/type/test_odometry.py @@ -17,7 +17,7 @@ import pytest from dimos.robot.unitree.type.odometry import Odometry -from dimos.utils.testing import SensorReplay +from dimos.utils.testing.replay import SensorReplay _EXPECTED_TOTAL_RAD = -4.05212 diff --git a/dimos/robot/unitree/unitree_skill_container.py b/dimos/robot/unitree/unitree_skill_container.py index d2f15b9efe..a79c061567 100644 --- a/dimos/robot/unitree/unitree_skill_container.py +++ b/dimos/robot/unitree/unitree_skill_container.py @@ -24,7 +24,9 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.navigation.base import NavigationState from dimos.utils.logging_config import setup_logger diff --git a/dimos/robot/unitree_webrtc/type/__init__.py b/dimos/robot/unitree_webrtc/type/__init__.py deleted file mode 100644 index 03ff4f4563..0000000000 --- a/dimos/robot/unitree_webrtc/type/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Compatibility re-exports for legacy dimos.robot.unitree_webrtc.type.* imports.""" - -import importlib - -__all__ = [] - - -def __getattr__(name: str): # type: ignore[no-untyped-def] - module = importlib.import_module("dimos.robot.unitree.type") - try: - return getattr(module, name) - except AttributeError as exc: - raise AttributeError(f"No {__name__} attribute {name}") from exc - - -def __dir__() -> list[str]: - module = importlib.import_module("dimos.robot.unitree.type") - return [name for name in dir(module) if not name.startswith("_")] diff --git a/dimos/rxpy_backpressure/__init__.py b/dimos/rxpy_backpressure/__init__.py deleted file mode 100644 index ff3b1f37c0..0000000000 --- a/dimos/rxpy_backpressure/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from dimos.rxpy_backpressure.backpressure import BackPressure - -__all__ = [BackPressure] diff --git a/dimos/simulation/__init__.py b/dimos/simulation/__init__.py deleted file mode 100644 index 1a68191a36..0000000000 --- a/dimos/simulation/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Try to import Isaac Sim components -try: - from .isaac import IsaacSimulator, IsaacStream -except ImportError: - IsaacSimulator = None # type: ignore[assignment, misc] - IsaacStream = None # type: ignore[assignment, misc] - -# Try to import Genesis components -try: - from .genesis import GenesisSimulator, GenesisStream -except ImportError: - GenesisSimulator = None # type: ignore[assignment, misc] - GenesisStream = None # type: ignore[assignment, misc] - -__all__ = ["GenesisSimulator", "GenesisStream", "IsaacSimulator", "IsaacStream"] diff --git a/dimos/simulation/base/__init__.py b/dimos/simulation/base/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/simulation/engines/__init__.py b/dimos/simulation/engines/__init__.py deleted file mode 100644 index d437f9a7cd..0000000000 --- a/dimos/simulation/engines/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Simulation engines for manipulator backends.""" - -from __future__ import annotations - -from typing import Literal - -from dimos.simulation.engines.base import SimulationEngine -from dimos.simulation.engines.mujoco_engine import MujocoEngine - -EngineType = Literal["mujoco"] - -_ENGINES: dict[EngineType, type[SimulationEngine]] = { - "mujoco": MujocoEngine, -} - - -def get_engine(engine_name: EngineType) -> type[SimulationEngine]: - return _ENGINES[engine_name] - - -__all__ = [ - "EngineType", - "SimulationEngine", - "get_engine", -] diff --git a/dimos/simulation/engines/base.py b/dimos/simulation/engines/base.py index d450614c62..58e76ecba6 100644 --- a/dimos/simulation/engines/base.py +++ b/dimos/simulation/engines/base.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from pathlib import Path - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.sensor_msgs.JointState import JointState class SimulationEngine(ABC): diff --git a/dimos/simulation/engines/mujoco_engine.py b/dimos/simulation/engines/mujoco_engine.py index ddaaa25ad3..2d1cdf92ac 100644 --- a/dimos/simulation/engines/mujoco_engine.py +++ b/dimos/simulation/engines/mujoco_engine.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from pathlib import Path - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.sensor_msgs.JointState import JointState logger = setup_logger() diff --git a/dimos/msgs/visualization_msgs/__init__.py b/dimos/simulation/engines/registry.py similarity index 56% rename from dimos/msgs/visualization_msgs/__init__.py rename to dimos/simulation/engines/registry.py index 0df5006c76..deadf3a404 100644 --- a/dimos/msgs/visualization_msgs/__init__.py +++ b/dimos/simulation/engines/registry.py @@ -12,8 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Visualization message types.""" +"""Simulation engine registry.""" -from dimos.msgs.visualization_msgs.EntityMarkers import EntityMarkers +from __future__ import annotations -__all__ = ["EntityMarkers"] +from typing import Literal + +from dimos.simulation.engines.base import SimulationEngine +from dimos.simulation.engines.mujoco_engine import MujocoEngine + +EngineType = Literal["mujoco"] + +_ENGINES: dict[EngineType, type[SimulationEngine]] = { + "mujoco": MujocoEngine, +} + + +def get_engine(engine_name: EngineType) -> type[SimulationEngine]: + return _ENGINES[engine_name] diff --git a/dimos/simulation/genesis/__init__.py b/dimos/simulation/genesis/__init__.py deleted file mode 100644 index 5657d9167b..0000000000 --- a/dimos/simulation/genesis/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .simulator import GenesisSimulator -from .stream import GenesisStream - -__all__ = ["GenesisSimulator", "GenesisStream"] diff --git a/dimos/simulation/isaac/__init__.py b/dimos/simulation/isaac/__init__.py deleted file mode 100644 index 2b9bdc082d..0000000000 --- a/dimos/simulation/isaac/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .simulator import IsaacSimulator -from .stream import IsaacStream - -__all__ = ["IsaacSimulator", "IsaacStream"] diff --git a/dimos/simulation/manipulators/__init__.py b/dimos/simulation/manipulators/__init__.py deleted file mode 100644 index 816de0a18d..0000000000 --- a/dimos/simulation/manipulators/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Simulation manipulator utilities.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface - from dimos.simulation.manipulators.sim_module import ( - SimulationModule, - SimulationModuleConfig, - simulation, - ) - -__all__ = [ - "SimManipInterface", - "SimulationModule", - "SimulationModuleConfig", - "simulation", -] - - -def __getattr__(name: str): # type: ignore[no-untyped-def] - if name == "SimManipInterface": - from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface - - return SimManipInterface - if name in {"SimulationModule", "SimulationModuleConfig", "simulation"}: - from dimos.simulation.manipulators.sim_module import ( - SimulationModule, - SimulationModuleConfig, - simulation, - ) - - return { - "SimulationModule": SimulationModule, - "SimulationModuleConfig": SimulationModuleConfig, - "simulation": simulation, - }[name] - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/dimos/simulation/manipulators/sim_manip_interface.py b/dimos/simulation/manipulators/sim_manip_interface.py index c829f0c864..6de570ae15 100644 --- a/dimos/simulation/manipulators/sim_manip_interface.py +++ b/dimos/simulation/manipulators/sim_manip_interface.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING from dimos.hardware.manipulators.spec import ControlMode, JointLimits, ManipulatorInfo -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.sensor_msgs.JointState import JointState if TYPE_CHECKING: from dimos.simulation.engines.base import SimulationEngine diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py index 20a55f1d02..5e873ba634 100644 --- a/dimos/simulation/manipulators/sim_module.py +++ b/dimos/simulation/manipulators/sim_module.py @@ -25,8 +25,10 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState -from dimos.simulation.engines import EngineType, get_engine +from dimos.msgs.sensor_msgs.JointCommand import JointCommand +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.RobotState import RobotState +from dimos.simulation.engines.registry import EngineType, get_engine from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface diff --git a/dimos/simulation/manipulators/test_sim_module.py b/dimos/simulation/manipulators/test_sim_module.py index 72408fefed..951d4790e3 100644 --- a/dimos/simulation/manipulators/test_sim_module.py +++ b/dimos/simulation/manipulators/test_sim_module.py @@ -17,7 +17,7 @@ import pytest -from dimos.protocol.rpc import RPCSpec +from dimos.protocol.rpc.spec import RPCSpec from dimos.simulation.manipulators.sim_module import SimulationModule diff --git a/dimos/simulation/mujoco/mujoco_process.py b/dimos/simulation/mujoco/mujoco_process.py index 21baec473f..2644dddd36 100755 --- a/dimos/simulation/mujoco/mujoco_process.py +++ b/dimos/simulation/mujoco/mujoco_process.py @@ -29,7 +29,7 @@ import open3d as o3d # type: ignore[import-untyped] from dimos.core.global_config import GlobalConfig -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.simulation.mujoco.constants import ( DEPTH_CAMERA_FOV, LIDAR_FPS, diff --git a/dimos/simulation/mujoco/person_on_track.py b/dimos/simulation/mujoco/person_on_track.py index a816b5f3ee..f19b49e4c6 100644 --- a/dimos/simulation/mujoco/person_on_track.py +++ b/dimos/simulation/mujoco/person_on_track.py @@ -19,7 +19,7 @@ from numpy.typing import NDArray from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.geometry_msgs.Pose import Pose class PersonPositionController: diff --git a/dimos/simulation/mujoco/shared_memory.py b/dimos/simulation/mujoco/shared_memory.py index 6dad60b4b4..f677863edf 100644 --- a/dimos/simulation/mujoco/shared_memory.py +++ b/dimos/simulation/mujoco/shared_memory.py @@ -21,7 +21,7 @@ import numpy as np from numpy.typing import NDArray -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.simulation.mujoco.constants import VIDEO_HEIGHT, VIDEO_WIDTH from dimos.utils.logging_config import setup_logger diff --git a/dimos/simulation/sim_blueprints.py b/dimos/simulation/sim_blueprints.py index 8b91ff817a..494b97ccbf 100644 --- a/dimos/simulation/sim_blueprints.py +++ b/dimos/simulation/sim_blueprints.py @@ -14,12 +14,10 @@ from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import ( # type: ignore[attr-defined] - JointCommand, - JointState, - RobotState, -) -from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.msgs.sensor_msgs.JointCommand import JointCommand +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.RobotState import RobotState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory from dimos.simulation.manipulators.sim_module import simulation from dimos.utils.data import LfsPath diff --git a/dimos/skills/__init__.py b/dimos/skills/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/skills/rest/__init__.py b/dimos/skills/rest/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/skills/unitree/__init__.py b/dimos/skills/unitree/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/spec/__init__.py b/dimos/spec/__init__.py deleted file mode 100644 index 1423bec9a1..0000000000 --- a/dimos/spec/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from dimos.spec.control import LocalPlanner -from dimos.spec.mapping import GlobalCostmap, GlobalPointcloud -from dimos.spec.nav import Nav -from dimos.spec.perception import Camera, Image, Pointcloud - -__all__ = [ - "Camera", - "GlobalCostmap", - "GlobalPointcloud", - "Image", - "LocalPlanner", - "Nav", - "Pointcloud", -] diff --git a/dimos/spec/control.py b/dimos/spec/control.py index 48d58a926a..b597b4faaf 100644 --- a/dimos/spec/control.py +++ b/dimos/spec/control.py @@ -15,7 +15,7 @@ from typing import Protocol from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import Twist +from dimos.msgs.geometry_msgs.Twist import Twist class LocalPlanner(Protocol): diff --git a/dimos/spec/mapping.py b/dimos/spec/mapping.py index 0ba88cfaa9..f35778f40b 100644 --- a/dimos/spec/mapping.py +++ b/dimos/spec/mapping.py @@ -15,8 +15,8 @@ from typing import Protocol from dimos.core.stream import Out -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 class GlobalPointcloud(Protocol): diff --git a/dimos/spec/nav.py b/dimos/spec/nav.py index 08f6f42b35..ae971e7b5c 100644 --- a/dimos/spec/nav.py +++ b/dimos/spec/nav.py @@ -15,8 +15,9 @@ from typing import Protocol from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Twist -from dimos.msgs.nav_msgs import Path +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.nav_msgs.Path import Path class Nav(Protocol): diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py index 1cfe352390..4fac65ad02 100644 --- a/dimos/spec/perception.py +++ b/dimos/spec/perception.py @@ -16,7 +16,10 @@ from dimos.core.stream import Out from dimos.msgs.nav_msgs.Odometry import Odometry as OdometryMsg -from dimos.msgs.sensor_msgs import CameraInfo, Image as ImageMsg, Imu, PointCloud2 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image as ImageMsg +from dimos.msgs.sensor_msgs.Imu import Imu +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 class Image(Protocol): diff --git a/dimos/stream/__init__.py b/dimos/stream/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/stream/audio/__init__.py b/dimos/stream/audio/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/stream/video_providers/__init__.py b/dimos/stream/video_providers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/teleop/__init__.py b/dimos/teleop/__init__.py deleted file mode 100644 index 8324113111..0000000000 --- a/dimos/teleop/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Teleoperation modules for DimOS.""" diff --git a/dimos/teleop/keyboard/__init__.py b/dimos/teleop/keyboard/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/teleop/keyboard/keyboard_teleop_module.py b/dimos/teleop/keyboard/keyboard_teleop_module.py index 854c0fbc22..a90dc3cf44 100644 --- a/dimos/teleop/keyboard/keyboard_teleop_module.py +++ b/dimos/teleop/keyboard/keyboard_teleop_module.py @@ -44,7 +44,7 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped # Force X11 driver to avoid OpenGL threading issues os.environ["SDL_VIDEODRIVER"] = "x11" diff --git a/dimos/teleop/phone/__init__.py b/dimos/teleop/phone/__init__.py deleted file mode 100644 index 552032a47b..0000000000 --- a/dimos/teleop/phone/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Phone teleoperation module for DimOS.""" - -from dimos.teleop.phone.phone_extensions import ( - SimplePhoneTeleop, - simple_phone_teleop_module, -) -from dimos.teleop.phone.phone_teleop_module import ( - PhoneTeleopConfig, - PhoneTeleopModule, - phone_teleop_module, -) - -__all__ = [ - "PhoneTeleopConfig", - "PhoneTeleopModule", - "SimplePhoneTeleop", - "phone_teleop_module", - "simple_phone_teleop_module", -] diff --git a/dimos/teleop/phone/phone_extensions.py b/dimos/teleop/phone/phone_extensions.py index 0f52fce2e0..c5cdc1fc80 100644 --- a/dimos/teleop/phone/phone_extensions.py +++ b/dimos/teleop/phone/phone_extensions.py @@ -20,7 +20,9 @@ """ from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import Twist, TwistStamped, Vector3 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.teleop.phone.phone_teleop_module import PhoneTeleopModule diff --git a/dimos/teleop/phone/phone_teleop_module.py b/dimos/teleop/phone/phone_teleop_module.py index cc55f1f180..3f32063cce 100644 --- a/dimos/teleop/phone/phone_teleop_module.py +++ b/dimos/teleop/phone/phone_teleop_module.py @@ -36,7 +36,9 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import Twist, TwistStamped, Vector3 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.std_msgs.Bool import Bool from dimos.utils.logging_config import setup_logger from dimos.utils.path_utils import get_project_root diff --git a/dimos/teleop/quest/__init__.py b/dimos/teleop/quest/__init__.py deleted file mode 100644 index 83daf4347b..0000000000 --- a/dimos/teleop/quest/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Quest teleoperation module.""" - -from dimos.teleop.quest.quest_extensions import ( - ArmTeleopModule, - TwistTeleopModule, - VisualizingTeleopModule, - arm_teleop_module, - twist_teleop_module, - visualizing_teleop_module, -) -from dimos.teleop.quest.quest_teleop_module import ( - Hand, - QuestTeleopConfig, - QuestTeleopModule, - QuestTeleopStatus, - quest_teleop_module, -) -from dimos.teleop.quest.quest_types import ( - Buttons, - QuestControllerState, - ThumbstickState, -) - -__all__ = [ - "ArmTeleopModule", - "Buttons", - "Hand", - "QuestControllerState", - "QuestTeleopConfig", - "QuestTeleopModule", - "QuestTeleopStatus", - "ThumbstickState", - "TwistTeleopModule", - "VisualizingTeleopModule", - # Blueprints - "arm_teleop_module", - "quest_teleop_module", - "twist_teleop_module", - "visualizing_teleop_module", -] diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index ac86a0325f..a3aa54ee08 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -22,7 +22,7 @@ ) from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_extensions import arm_teleop_module, visualizing_teleop_module from dimos.teleop.quest.quest_types import Buttons diff --git a/dimos/teleop/quest/quest_extensions.py b/dimos/teleop/quest/quest_extensions.py index c92ac55a43..46e868837d 100644 --- a/dimos/teleop/quest/quest_extensions.py +++ b/dimos/teleop/quest/quest_extensions.py @@ -25,7 +25,8 @@ from pydantic import Field from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseStamped, TwistStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped from dimos.teleop.quest.quest_teleop_module import Hand, QuestTeleopConfig, QuestTeleopModule from dimos.teleop.quest.quest_types import Buttons, QuestControllerState from dimos.teleop.utils.teleop_visualization import ( diff --git a/dimos/teleop/quest/quest_teleop_module.py b/dimos/teleop/quest/quest_teleop_module.py index 3c8e6e9812..5868aab620 100644 --- a/dimos/teleop/quest/quest_teleop_module.py +++ b/dimos/teleop/quest/quest_teleop_module.py @@ -37,8 +37,8 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Joy +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Joy import Joy from dimos.teleop.quest.quest_types import Buttons, QuestControllerState from dimos.teleop.utils.teleop_transforms import webxr_to_robot from dimos.utils.logging_config import setup_logger diff --git a/dimos/teleop/quest/quest_types.py b/dimos/teleop/quest/quest_types.py index 7fd991a76c..7e7cfc7620 100644 --- a/dimos/teleop/quest/quest_types.py +++ b/dimos/teleop/quest/quest_types.py @@ -18,8 +18,8 @@ from dataclasses import dataclass, field from typing import ClassVar -from dimos.msgs.sensor_msgs import Joy -from dimos.msgs.std_msgs import UInt32 +from dimos.msgs.sensor_msgs.Joy import Joy +from dimos.msgs.std_msgs.UInt32 import UInt32 @dataclass diff --git a/dimos/teleop/utils/__init__.py b/dimos/teleop/utils/__init__.py deleted file mode 100644 index ae8c375e8f..0000000000 --- a/dimos/teleop/utils/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Teleoperation utilities.""" diff --git a/dimos/teleop/utils/teleop_transforms.py b/dimos/teleop/utils/teleop_transforms.py index 15fd3be120..f1e9e9381d 100644 --- a/dimos/teleop/utils/teleop_transforms.py +++ b/dimos/teleop/utils/teleop_transforms.py @@ -22,7 +22,7 @@ import numpy as np from scipy.spatial.transform import Rotation as R # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.utils.transform_utils import matrix_to_pose, pose_to_matrix if TYPE_CHECKING: diff --git a/dimos/teleop/utils/teleop_visualization.py b/dimos/teleop/utils/teleop_visualization.py index a59b0666ef..5a7acd06e9 100644 --- a/dimos/teleop/utils/teleop_visualization.py +++ b/dimos/teleop/utils/teleop_visualization.py @@ -24,7 +24,7 @@ from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped logger = setup_logger() diff --git a/dimos/perception/experimental/temporal_memory/__init__.py b/dimos/test_no_init_files.py similarity index 50% rename from dimos/perception/experimental/temporal_memory/__init__.py rename to dimos/test_no_init_files.py index 1056e82e8b..39efb7ad24 100644 --- a/dimos/perception/experimental/temporal_memory/__init__.py +++ b/dimos/test_no_init_files.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Temporal memory package.""" +from dimos.constants import DIMOS_PROJECT_ROOT -from .frame_window_accumulator import Frame, FrameWindowAccumulator -from .temporal_memory import TemporalMemory, TemporalMemoryConfig, temporal_memory -from .temporal_state import TemporalState -from .window_analyzer import WindowAnalyzer -__all__ = [ - "Frame", - "FrameWindowAccumulator", - "TemporalMemory", - "TemporalMemoryConfig", - "TemporalState", - "WindowAnalyzer", - "temporal_memory", -] +def test_no_init_files(): + dimos_dir = DIMOS_PROJECT_ROOT / "dimos" + init_files = sorted(dimos_dir.rglob("__init__.py")) + if init_files: + listing = "\n".join(f" - {f.relative_to(dimos_dir)}" for f in init_files) + raise AssertionError( + f"Found __init__.py files in dimos/:\n{listing}\n\n" + "__init__.py files are not allowed because they lead to unnecessary " + "extraneous imports. Everything should be imported straight from the " + "source module." + ) diff --git a/dimos/types/ros_polyfill.py b/dimos/types/ros_polyfill.py index 4bad99740d..70140336b8 100644 --- a/dimos/types/ros_polyfill.py +++ b/dimos/types/ros_polyfill.py @@ -15,7 +15,7 @@ try: from geometry_msgs.msg import Vector3 # type: ignore[attr-defined] except ImportError: - from dimos.msgs.geometry_msgs import Vector3 + from dimos.msgs.geometry_msgs.Vector3 import Vector3 try: from geometry_msgs.msg import ( # type: ignore[attr-defined] diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py index 7de82e8f9a..e62b275dfc 100644 --- a/dimos/types/test_timestamped.py +++ b/dimos/types/test_timestamped.py @@ -20,7 +20,7 @@ from reactivex.scheduler import ThreadPoolScheduler from dimos.memory.timeseries.inmemory import InMemoryStore -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.types.timestamped import ( Timestamped, TimestampedBufferCollection, @@ -28,9 +28,9 @@ to_datetime, to_ros_stamp, ) -from dimos.utils import testing from dimos.utils.data import get_data from dimos.utils.reactive import backpressure +from dimos.utils.testing.replay import TimedSensorReplay def test_timestamped_dt_method() -> None: @@ -296,7 +296,7 @@ def spy(image): # sensor reply of raw video frames video_raw = ( - testing.TimedSensorReplay( + TimedSensorReplay( "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() ) .stream(speed) diff --git a/dimos/utils/cli/__init__.py b/dimos/utils/cli/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/utils/cli/agentspy/demo_agentspy.py b/dimos/utils/cli/agentspy/demo_agentspy.py index 5229295038..851229131b 100755 --- a/dimos/utils/cli/agentspy/demo_agentspy.py +++ b/dimos/utils/cli/agentspy/demo_agentspy.py @@ -24,7 +24,7 @@ ToolMessage, ) -from dimos.protocol.pubsub import lcm # type: ignore[attr-defined] +import dimos.protocol.pubsub.impl.lcmpubsub as lcm from dimos.protocol.pubsub.impl.lcmpubsub import PickleLCM diff --git a/dimos/utils/decorators/__init__.py b/dimos/utils/decorators/__init__.py deleted file mode 100644 index d0f91a4939..0000000000 --- a/dimos/utils/decorators/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Decorators and accumulators for rate limiting and other utilities.""" - -from .accumulators import Accumulator, LatestAccumulator, RollingAverageAccumulator -from .decorators import CachedMethod, limit, retry, simple_mcache, ttl_cache - -__all__ = [ - "Accumulator", - "CachedMethod", - "LatestAccumulator", - "RollingAverageAccumulator", - "limit", - "retry", - "simple_mcache", - "ttl_cache", -] diff --git a/dimos/utils/decorators/test_decorators.py b/dimos/utils/decorators/test_decorators.py index 98545a2e37..8923151667 100644 --- a/dimos/utils/decorators/test_decorators.py +++ b/dimos/utils/decorators/test_decorators.py @@ -16,7 +16,8 @@ import pytest -from dimos.utils.decorators import RollingAverageAccumulator, limit, retry, simple_mcache, ttl_cache +from dimos.utils.decorators.accumulators import RollingAverageAccumulator +from dimos.utils.decorators.decorators import limit, retry, simple_mcache, ttl_cache def test_limit() -> None: diff --git a/dimos/utils/demo_image_encoding.py b/dimos/utils/demo_image_encoding.py index 42374029f2..84b91acf79 100644 --- a/dimos/utils/demo_image_encoding.py +++ b/dimos/utils/demo_image_encoding.py @@ -34,7 +34,7 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import JpegLcmTransport, LCMTransport -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.utils.fast_image_generator import random_image diff --git a/dimos/utils/docs/test_doclinks.py b/dimos/utils/docs/test_doclinks.py index 7da6a6281b..a5a50b03e5 100644 --- a/dimos/utils/docs/test_doclinks.py +++ b/dimos/utils/docs/test_doclinks.py @@ -16,7 +16,9 @@ from pathlib import Path -from doclinks import ( +import pytest + +from dimos.utils.docs.doclinks import ( build_doc_index, build_file_index, extract_other_backticks, @@ -27,7 +29,6 @@ score_path_similarity, split_by_ignore_regions, ) -import pytest # Use the actual repo root REPO_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/dimos/utils/reactive.py b/dimos/utils/reactive.py index 4397e0171e..623556d6b7 100644 --- a/dimos/utils/reactive.py +++ b/dimos/utils/reactive.py @@ -24,7 +24,7 @@ from reactivex.observable import Observable from reactivex.scheduler import ThreadPoolScheduler -from dimos.rxpy_backpressure import BackPressure +from dimos.rxpy_backpressure.backpressure import BackPressure from dimos.utils.threadpool import get_scheduler T = TypeVar("T") diff --git a/dimos/utils/test_transform_utils.py b/dimos/utils/test_transform_utils.py index 7923124c9f..77852a7bb2 100644 --- a/dimos/utils/test_transform_utils.py +++ b/dimos/utils/test_transform_utils.py @@ -16,7 +16,10 @@ import pytest from scipy.spatial.transform import Rotation as R -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.utils import transform_utils diff --git a/dimos/utils/testing/__init__.py b/dimos/utils/testing/__init__.py deleted file mode 100644 index 568cd3604f..0000000000 --- a/dimos/utils/testing/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "moment": ["Moment", "OutputMoment", "SensorMoment"], - "replay": ["SensorReplay", "TimedSensorReplay", "TimedSensorStorage"], - }, -) diff --git a/dimos/utils/testing/test_moment.py b/dimos/utils/testing/test_moment.py index 75f11d2657..dcca3d7d01 100644 --- a/dimos/utils/testing/test_moment.py +++ b/dimos/utils/testing/test_moment.py @@ -14,9 +14,12 @@ import time from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped, Transform -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 -from dimos.protocol.tf import TF +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.protocol.tf.tf import TF from dimos.robot.unitree.go2 import connection from dimos.utils.data import get_data from dimos.utils.testing.moment import Moment, SensorMoment diff --git a/dimos/utils/testing/test_replay.py b/dimos/utils/testing/test_replay.py index e3020777b4..10ace353f7 100644 --- a/dimos/utils/testing/test_replay.py +++ b/dimos/utils/testing/test_replay.py @@ -16,7 +16,7 @@ from reactivex import operators as ops -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index ed82f6116f..bfd38ce14f 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -16,7 +16,10 @@ import numpy as np from scipy.spatial.transform import Rotation as R # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 def normalize_angle(angle: float) -> float: diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 6729f143cd..12f998d96d 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -38,7 +38,8 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig -from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.pubsub.patterns import Glob, pattern_matches from dimos.protocol.pubsub.spec import SubscribeAllCapable diff --git a/dimos/web/__init__.py b/dimos/web/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/web/dimos_interface/__init__.py b/dimos/web/dimos_interface/__init__.py deleted file mode 100644 index 3bdc622cee..0000000000 --- a/dimos/web/dimos_interface/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Dimensional Interface package -""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "api.server": ["FastAPIServer"], - }, -) diff --git a/dimos/web/dimos_interface/api/__init__.py b/dimos/web/dimos_interface/api/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/web/websocket_vis/costmap_viz.py b/dimos/web/websocket_vis/costmap_viz.py index 21309c94bc..f24628e6c7 100644 --- a/dimos/web/websocket_vis/costmap_viz.py +++ b/dimos/web/websocket_vis/costmap_viz.py @@ -19,7 +19,7 @@ import numpy as np -from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid class CostmapViz: diff --git a/dimos/web/websocket_vis/path_history.py b/dimos/web/websocket_vis/path_history.py index 39b6be08a3..c69e7e9508 100644 --- a/dimos/web/websocket_vis/path_history.py +++ b/dimos/web/websocket_vis/path_history.py @@ -17,7 +17,7 @@ This is a minimal implementation to support websocket visualization. """ -from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 class PathHistory: diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 7a5c9587e1..5514144570 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -48,11 +48,15 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out +from dimos.mapping.models import LatLon from dimos.mapping.occupancy.gradient import gradient from dimos.mapping.occupancy.inflation import simple_inflate -from dimos.mapping.types import LatLon -from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped, Vector3 -from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.utils.logging_config import setup_logger from .optimized_costmap import OptimizedCostmapEncoder diff --git a/docs/capabilities/navigation/native/index.md b/docs/capabilities/navigation/native/index.md index a750d3bfba..6a8c5224e9 100644 --- a/docs/capabilities/navigation/native/index.md +++ b/docs/capabilities/navigation/native/index.md @@ -118,7 +118,7 @@ All visualization layers shown together ## Blueprint Composition -The navigation stack is composed in the [`unitree_go2`](/dimos/robot/unitree/go2/blueprints/__init__.py) blueprint: +The navigation stack is composed in the [`unitree_go2`](/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py) blueprint: ```python fold output=assets/go2_blueprint.svg from dimos.core.blueprints import autoconnect diff --git a/docs/usage/transports/index.md b/docs/usage/transports/index.md index b930671906..db931872bd 100644 --- a/docs/usage/transports/index.md +++ b/docs/usage/transports/index.md @@ -81,7 +81,7 @@ We’ll go through these layers top-down. See [Blueprints](/docs/usage/blueprints.md) for the blueprint API. -From [`unitree/go2/blueprints/__init__.py`](/dimos/robot/unitree/go2/blueprints/__init__.py). +From [`unitree/go2/blueprints/smart/unitree_go2.py`](/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py). Example: rebind a few streams from the default `LCMTransport` to `ROSTransport` (defined at [`transport.py`](/dimos/core/transport.py#L226)) so you can visualize in **rviz2**. diff --git a/docs/usage/visualization.md b/docs/usage/visualization.md index 809f7881e4..57ad460354 100644 --- a/docs/usage/visualization.md +++ b/docs/usage/visualization.md @@ -96,7 +96,7 @@ This happens on lower-end hardware (NUC, older laptops) with large maps. ### Increase Voxel Size -Edit [`dimos/robot/unitree/go2/blueprints/__init__.py`](/dimos/robot/unitree/go2/blueprints/__init__.py) line 82: +Edit [`dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py`](/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py): ```python # Before (high detail, slower on large maps) diff --git a/examples/simplerobot/simplerobot.py b/examples/simplerobot/simplerobot.py index 2a1867b37c..517684d7cd 100644 --- a/examples/simplerobot/simplerobot.py +++ b/examples/simplerobot/simplerobot.py @@ -30,7 +30,11 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Twist, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 def apply_twist(pose: Pose, twist: Twist, dt: float) -> Pose: diff --git a/pyproject.toml b/pyproject.toml index 4370944b27..722e3b0485 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -377,6 +377,7 @@ python_version = "3.12" incremental = true strict = true warn_unused_ignores = false +explicit_package_bases = true exclude = "^dimos/models/Detic(/|$)|^dimos/rxpy_backpressure(/|$)|.*/test_.|.*/conftest.py*" [[tool.mypy.overrides]] @@ -429,7 +430,7 @@ env = [ "GOOGLE_MAPS_API_KEY=AIzafake_google_key", "PYTHONWARNINGS=ignore:cupyx.jit.rawkernel is experimental:FutureWarning", ] -addopts = "-v -r a -p no:warnings --color=yes -m 'not (tool or slow or mujoco)'" +addopts = "-v -r a -p no:warnings -p no:launch_testing -p no:launch_ros --import-mode=importlib --color=yes -m 'not (tool or slow or mujoco)'" asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function" From 1ec4227c41be450533b3090931e85205ff7901f3 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 14 Mar 2026 15:30:23 -0700 Subject: [PATCH 10/42] fix(deps): skip pyrealsense2 on macOS (#1556) * fix(deps): skip pyrealsense2 on macOS (not available) * - --- pyproject.toml | 2 +- uv.lock | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 722e3b0485..1757e01a8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -202,7 +202,7 @@ manipulation = [ # Hardware SDKs "piper-sdk", - "pyrealsense2", + "pyrealsense2; sys_platform != 'darwin'", "xarm-python-sdk>=1.17.0", # Visualization (Optional) diff --git a/uv.lock b/uv.lock index 5ec39fff59..e6ba8198a8 100644 --- a/uv.lock +++ b/uv.lock @@ -1879,7 +1879,7 @@ manipulation = [ { name = "matplotlib" }, { name = "piper-sdk" }, { name = "plotly" }, - { name = "pyrealsense2" }, + { name = "pyrealsense2", marker = "sys_platform != 'darwin'" }, { name = "pyyaml" }, { name = "xacro" }, { name = "xarm-python-sdk" }, @@ -2066,7 +2066,7 @@ requires-dist = [ { name = "pydantic-settings", marker = "extra == 'docker'", specifier = ">=2.11.0,<3" }, { name = "pygame", marker = "extra == 'sim'", specifier = ">=2.6.1" }, { name = "pymavlink", marker = "extra == 'drone'" }, - { name = "pyrealsense2", marker = "extra == 'manipulation'" }, + { name = "pyrealsense2", marker = "sys_platform != 'darwin' and extra == 'manipulation'" }, { name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==0.26.0" }, { name = "pytest-env", marker = "extra == 'dev'", specifier = "==1.1.5" }, From b3177fde166069974df037a0c4906c52b9336bbe Mon Sep 17 00:00:00 2001 From: leshy Date: Sun, 15 Mar 2026 07:21:56 +0200 Subject: [PATCH 11/42] Feat/memory2 (#1536) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * memory plans * spec iteration * spec iteration * query objects spec * mem3 iteration * live/passive transforms * initial pass on memory * transform materialize * sqlite schema: decomposed pose columns, separate payload table, R*Tree spatial index, lazy data loading - Pose stored as 7 real columns (x/y/z + quaternion) instead of blob, enabling R*Tree spatial indexing - Payload moved to separate {name}_payload table with lazy loading via _data_loader closure - R*Tree virtual table created per stream for .near() bounding-box queries - Added __iter__ to Stream for lazy iteration via fetch_pages - Added embedding_stream() to Session ABC - Updated _streams metadata with parent_stream and embedding_dim columns - Codec module extracted (LcmCodec, PickleCodec, codec_for_type) - Fixed broken memory_old.timeseries imports (memory.timeseries → memory_old.timeseries) - Tests now use real Image data from TimedSensorReplay("unitree_go2_bigoffice/video") - 32/32 tests passing, mypy clean * JpegCodec for Image storage (43x smaller), ingest helpers, QualityWindowTransformer, E2E test - Add JpegCodec as default codec for Image types (2.76MB → 64KB per frame) - Preserve frame_id in JPEG header; ts stored in meta table - Add ingest() helper for bulk-loading (ts, payload) iterables into streams - Add QualityWindowTransformer: best-frame-per-window (supports backfill + live) - EmbeddingTransformer sets output_type=Embedding automatically - Require payload_type when creating new streams (no silent PickleCodec fallback) - TransformStream.store() accepts payload_type, propagated through materialize_transform - E2E test: 5min video → sharpness filter → CLIP embed → text search - Move test_sqlite.py next to sqlite.py, update Image comparisons for lossy codec - Add sqlite-vec dependency * Wire parent_id lineage through transforms for automatic source data projection - Add parent_id to Observation, append(), do_append(), and _META_COLS - All transformers (PerItem, QualityWindow, Embedding) pass obs.id as parent_id - SqliteEmbeddingBackend._row_to_obs() wires _source_data_loader via parent_id - EmbeddingObservation.data now auto-projects to parent stream's payload (e.g. Image) - No more timestamp-matching hacks to find source data from embedding results * Wire parent_stream into _streams registry, add tasks.md gap analysis - materialize_transform() now UPDATEs _streams.parent_stream so stream-level lineage is discoverable (prerequisite for .join()) - Fix mypy: narrow parent_table type in _source_loader closure - Add plans/memory/tasks.md documenting all spec-vs-impl gaps * Implement project_to() for cross-stream lineage projection Adds LineageFilter that compiles to nested SQL subqueries walking the parent_id chain. project_to(target) returns a chainable target Stream using the same _with_filter mechanism as .after(), .near(), etc. Also fixes _session propagation in search_embedding/search_text. * Make search_embedding auto-project to source stream EmbeddingStream is a semantic index — search results should be source observations (Images), not Embedding objects. search_embedding now auto-projects via project_to when lineage exists, falling back to EmbeddingStream for standalone streams without parent lineage. * CaptionTransformer + Florence2 batch fix - Add CaptionTransformer: wraps Captioner/VlModel, uses caption_batch() for backfill efficiency, auto-creates TextStream with FTS on .store() - Fix Florence2 caption_batch() emitting tokens (skip_special_tokens) - E2E script now uses transform pipeline for captioning search results * ObservationSet: fetch() returns list-like + stream-like result set fetch() now returns ObservationSet instead of plain list, keeping you in the Stream API. This enables fork-and-zip (one DB query, two uses) and in-memory re-filtering without re-querying the database. - Add matches(obs) to all filter dataclasses for in-Python evaluation - Add ListBackend (in-memory StreamBackend) and ObservationSet class - Filtered .appended reactive subscription via matches() infrastructure - Update e2e export script to use fork-and-zip pattern - 20 new tests (64 total, all passing) * search_embedding accepts str/image with auto-embedding EmbeddingStream now holds an optional model reference, so search_embedding auto-dispatches: str → embed_text(), image → embed(), Embedding/list[float] → use directly. The model is wired through materialize_transform and also accepted via embedding_stream(). * Add sqlite_vec to mypy ignore list (no type stubs available) * Fix mypy + pytest errors across memory and memory_old modules - Fix SpatialImage/SpatialEntry dataclass hierarchy in memory_old - Fix import path in memory_old/test_embedding.py - Add None guard for obs.ts in run_viz_demo.py - Add payload_type/session kwargs to base Stream.store() signature - Type-annotate embeddings as EmbeddingStream in run_e2e_export.py - Add similarity scores, raw search mode, pose ingest, viz pipeline * Improve similarity heatmap with normalized values and distance spread - Normalize similarity scores relative to min/max (CLIP clusters in narrow band) - Add distance_transform_edt spread so dots radiate outward, fading to 0 - Bump default search k to 200 for denser heatmaps * Remove plans/ from tracking (kept locally) * Address Greptile review: SQL injection guards, distance ordering, stubs - Validate stream names and tag keys as SQL identifiers - Allowlist order_by fields to {id, ts} - Re-sort vector search results by distance rank after IN-clause fetch - Make TagsFilter hashable (tuple of pairs instead of dict) - Remove dead code in memory_old/embedding.py - Add scipy-stubs, fix distance_transform_edt type annotations * Add memory Rerun visualization, fix stream iteration, update docs - Add dimos/memory/rerun.py: to_rerun() sends stream data to Rerun with auto-derived entity paths and no wall-clock timeline contamination - Fix Stream.fetch_pages() to respect limit_val (was always overridden by batch_size, making .limit() ineffective during iteration) - Update viz.py: normalize similarities with 20% floor cutoff, sort timeline by timestamp, add log_top_images() - Convert run_e2e_export.py to pytest with cached DB fixture - Update plans/memory docs to match current implementation * Rename run_e2e_export → test_e2e_export, delete viz.py + run_viz_demo, fix mypy - Rename to test_e2e_export.py (it's a pytest file, not a standalone script) - Fix Generator return type and type: ignore for mypy - Delete viz.py (replaced by rerun.py) and run_viz_demo.py - Update docs/api.md to reference rerun.py instead of viz.py * added docs * removed tasks.md * Optimize memory pipeline: TurboJPEG codec, sharpness downsample, thread reduction - Switch JpegCodec from cv2.imencode to TurboJPEG (2-5x faster encode/decode) - Lower default JPEG quality from 90 to 50 for smaller storage footprint - Downscale sharpness computation to 160px Laplacian variance (10-20x cheaper) - Add MemoryModule with plain-Python sharpness windowing (no rx timer overhead) - Limit OpenCV threads: 2 globally in worker entrypoint, 1 in MemoryModule - Cap global rx ThreadPoolScheduler at 8 workers (was unbounded cpu_count) - Refactor SqliteEmbeddingBackend/SqliteTextBackend to use _post_insert hook - Encode payload before meta insert to prevent orphaned rows on codec error - Add `dimos ps` CLI command and `dps` entrypoint for non-interactive process listing - Add unitree-go2-memory blueprint * text embedding transformer * cleanup * Use Codec protocol type instead of concrete union, remove dead _pose_codec * correct db sessions * record module cleanup * memory elements are now Resource, simplification of memory Module * Rename stream.appended to stream.observable()/subscribe() Mirror the core In stream API — memory streams now expose .observable() and .subscribe() instead of the .appended property. * repr, embedding fetch simplification * Make Observation generic: Observation[T] with full type safety * Simplify Stream._clone with copy.copy, remove subclass overrides * loader refactor * Extract backend.load_data(), add stream.load_data(obs) public API SQL now lives on the backend, closures are thin thread-guarded wrappers. * Add rich colored __str__ to Stream and Filter types print() now shows colored output (class=cyan, type=yellow, name=green, filters=cyan, pipes=dim). __repr__ stays plain for logs. * Unify __repr__ and __str__ via _rich_text().plain, remove duplicate rendering * renamed types to type * one -> first, time range * getitem for streams * readme sketch * bigoffice db in lfs, sqlite accepts Path * projection transformers * stream info removed, stream accessor helper, TS unique per stream * Add colored summary() output and model= param to search_embedding summary() now renders the rich-text stream header with colored type info, count, timestamps, and duration. search_embedding() accepts an optional model= override so callers don't need to attach a model to the stream. * stream delete * florence model detail settings and prefix filter * extracted formatting to a separate file * extract rich text rendering to formatting.py, add Stream.name, fix stale tests Move all _rich_text methods from type.py and stream.py into a central formatting.py module with a single rich_text() dispatch function. Replace relative imports with absolute imports across memory/. Add Stream.name property, remove VLMDetectionTransformer tests, fix stale test assertions. * matching based on streams * projection experiments * projection bugfix * observationset typing fix * detections, cleanup * mini adjustments * transform chaining * memory2: lazy pull-based stream system Greenfield rewrite of the memory module using sync generators. Every .filter(), .transform(), .map() returns a new Stream — no computation until iteration. Backends handle query application; transforms are Iterator[Obs] → Iterator[Obs]. Live mode with backpressure buffers bridges push sources to pull consumers. * memory2: fix typing — zero type:ignore, proper generics - Closed → ClosedError (N818) - Callable types for _loader, Disposable.fn, backend_factory, PredicateFilter.fn - Disposable typed in stream._live_sub - assert+narrowing instead of type:ignore in KeepLast.take, _iter_transform - cast only in Session.stream (unavoidable generic cache lookup) * memory2: fix .live() on transform streams — reject with clear error Live items from the backend buffer were bypassing the transform chain entirely. The fix: .live() is only valid on backend-backed streams; transforms downstream just see an infinite iterator. * memory2: replace custom Disposable with rxpy DisposableBase Use reactivex.abc.DisposableBase in protocols and reactivex.disposable.Disposable in implementations, consistent with dimos's existing Resource pattern. * memory2: extract filters and StreamQuery from type.py into filter.py type.py now only contains Observation and its helpers. * memory2: store transform on Stream node, not as source tuple Stream._source is now `Backend | Stream` instead of `Backend | tuple[Stream, Transformer]`. The transformer lives on the stream that owns it (`_xf` field), not bundled into the source pointer. Fix .map() tests to pass Observation→Observation lambdas. Remove live mode tests (blocked by nvidia driver D-state in root conftest autoconf). * memory2: move live logic from Stream into Backend via StreamQuery Live is now just a query parameter (live_buffer on StreamQuery). Stream.live() is a one-liner query modifier — the backend handles subscription, dedup, and backpressure internally. Stream has zero live implementation. * memory2: extract impl/ layer with MemoryStore and SqliteStore scaffold Move ListBackend from backend.py into impl/memory.py alongside new MemorySession and MemoryStore. Add SqliteStore/SqliteSession/SqliteBackend skeleton in impl/sqlite.py. Refactor Store and Session to abstract base classes with _create_backend() hook. backend.py now only contains the Backend and LiveBackend protocols. Also fix doclinks: disambiguate memory.py reference in transports docs, and include source .md file path in all doclinks error messages. * memory2: add buffer.py docstring and extract buffer tests to test_buffer.py * memory2: add Codec protocol and grid test for store implementations Introduce codecs/ package with the Codec[T] protocol (encode/decode). Thread payload_type through Session._create_backend() so backends can select the right codec. Add test_impl.py grid test that runs the same 15 basic tests against every store backend (memory passes, sqlite xfail until implemented). * memory2: add codec implementations (pickle, lcm, jpeg) with grid tests PickleCodec for arbitrary objects, LcmCodec for DimosMsg types, JpegCodec for Image types with TurboJPEG. codec_for() auto-selects based on payload type. Grid test verifies roundtrip preservation across all three codecs using real PoseStamped and camera frame data. * resource: add context manager to Resource; make Store/Session Resources Resource.__enter__/__exit__ calls start()/stop(), giving every Resource context-manager support. memory2 Store and Session now extend Resource instead of bare ABC, replacing close() with the standard start()/stop() lifecycle. * resource: add CompositeResource with owned disposables CompositeResource extends Resource with a _disposables list and own() method. stop() disposes all children — gives tree-structured resources automatic cleanup. Session and Store now extend CompositeResource. * memory2: add BlobStore ABC with File and SQLite implementations BlobStore separates payload blob storage from metadata indexing. FileBlobStore stores on disk ({root}/{stream}/{key}.bin), SqliteBlobStore uses per-stream tables. Grid tests cover both. * memory2: move blobstore.md into blobstore/ as module readme * memory2: add embedding layer, vector/text search, live safety guards - EmbeddedObservation with derive() promotion semantics - EmbedImages/EmbedText transformers using EmbeddingModel ABC - .search(vec, k) and .search_text() on Stream with Embedding type - VectorStore ABC for pluggable vector backends - Backend.append() takes Observation directly (not kwargs) - is_live() walks source chain; search/order_by/fetch/count guard against live streams with TypeError instead of silent hang - .drain() terminal for constant-memory side-effect pipelines - Rewrite test_stream.py to use Stream layer (no manual backends) * memory2: add documentation for streaming model, codecs, and backends - README.md: architecture overview, module index, quick start - streaming.md: lazy vs materializing vs terminal evaluation model - codecs/README.md: codec protocol, built-in codecs, writing new ones - impl/README.md: backend guide with query contract and grid test setup * query application refactor * memory2: replace LiveBackend with pluggable LiveChannel, add Configurable pattern - Replace LiveBackend protocol with LiveChannel ABC (SubjectChannel for in-memory fan-out, extensible to Redis/Postgres for cross-process) - Add livechannel/ subpackage with SubjectChannel implementation - Make Store and Session extend Configurable[ConfigT] with StoreConfig and SessionConfig dataclasses - Remove redundant Session._backends dict (Backend lives in Stream._source) - Make list_streams() and delete_stream() abstract on Session so implementations can query persisted streams - StreamNamespace delegates to list_streams()/stream() instead of accessing _streams directly - Remove LiveBackend isinstance guard from stream.py — all backends now have a built-in LiveChannel * memory2: make backends Configurable, add session→stream config propagation Session.stream() now merges session-level defaults with per-stream overrides and forwards them to _create_backend(). Backends (ListBackend, SqliteBackend) extend Configurable[BackendConfig] so they receive live_channel, blob_store, and vector_store through the standard config pattern instead of explicit constructor params. * memory2: wire VectorStore into ListBackend, add MemoryVectorStore ListBackend.append() now delegates embedding storage to the pluggable VectorStore when configured. _iterate_snapshot() uses VectorStore.search() for ANN ranking when available, falling back to brute-force in StreamQuery.apply(). Adds MemoryVectorStore (in-memory brute-force impl) and tests verifying end-to-end config propagation including per-stream vector_store overrides. * memory2: wire BlobStore into ListBackend with lazy/eager blob loading Payloads are encoded via auto-selected codec and externalized to the pluggable BlobStore on append. Observations become lightweight metadata with lazy loaders that fetch+decode on first .data access. Per-stream eager_blobs toggle pre-loads data during iteration. * memory2: allow bare generator functions as stream transforms stream.transform() now accepts Iterator→Iterator callables in addition to Transformer subclasses, for quick stateful pipelines. * memory2: update docs to reflect current API - impl/README: LiveBackend → LiveChannel, add Configurable pattern, update _create_backend and Store/Session signatures - embeddings.md: fix Observation fields (_source → _loader), embedding type (np.ndarray → Embedding), remove unimplemented source chain, use temporal join for lineage - streaming.md: note .transform() accepts bare callables - README: add FnIterTransformer, generator function example * memory2: implement full SqliteBackend with vec0 vector search, JSONB tags, and SQL filter pushdown - Add SqliteVectorStore using sqlite-vec vec0 virtual tables with cosine distance - Implement SqliteBackend: append, iterate (snapshot/live/vector), count with SQL pushdown - Add SQL filter compilation for time, tags, and range filters; Python fallback for NearFilter/PredicateFilter - Wire SqliteSession with _streams registry table, codec persistence, shared store auto-wiring - Support eager blob loading via co-located JOIN optimization - Load sqlite-vec extension in SqliteStore with graceful fallback - Remove xfail markers from test_impl.py — all 36 grid tests pass * memory2: stream rows via cursor pagination instead of fetchall() Add configurable page_size (default 256) to BackendConfig. SqliteBackend now iterates the cursor with arraysize set to page_size for memory-efficient streaming of large result sets. * memory2: add lazy/eager blob tests and spy store delegation grid tests - TestBlobLoading: verify lazy (_UNLOADED sentinel + loader) vs eager (JOIN inline) paths for SqliteBackend, plus value equivalence between both modes - TestStoreDelegation: grid tests with SpyBlobStore/SpyVectorStore injected into both memory and sqlite backends — verify append→put, iterate→get, and search delegation through the pluggable store ABCs * memory2: add R*Tree spatial index for NearFilter SQL pushdown, add e2e tests R*Tree virtual tables enable O(log n) pose-based proximity queries instead of full-table Python scans. E2E tests verify import pipeline and read-only queries against real robot sensor data (video + lidar). * auto index tags * memory/stream str, and observables * live stream is a resource * readme work * streams and intro * renamed readme to arch * Rename memory2 → memory, fix all imports and type errors - Replace all dimos.memory2 imports with dimos.memory - Make concrete filter classes inherit from Filter ABC - Fix mypy errors: type narrowing, Optional guards, annotation mismatches - Fix test_impl.py: filter_tags() → tags() - Remove intro.py (superseded by intro.md) - Delete old dimos/memory2/ directory * Revert memory rename: restore memory/ from dev, new code lives in memory2/ - Restore dimos/memory/ (old timeseries memory) to match dev - Move new memory system back to dimos/memory2/ with corrected imports - Delete dimos/memory_old/ (no longer needed) - Fix memory_old imports in tf.py, timestamped.py, replay.py → dimos.memory - Remove dps CLI util and pyproject entry - Remove unitree_go2_memory blueprint (depends on deleted modules) * Remove stray old memory module references - Delete empty dimos/memory/impl/sqlite.py - Remove nonexistent memory-module entry from all_blueprints - Restore codeblocks.md from dev * Remove LFS test databases from PR These were added during development but shouldn't be in the PR. * Address review findings: SQL injection guards, type fixes, cleanup - Remove dead dict(hits) and thread-affinity assertion in SqliteBackend - Validate order_field and tag keys against _IDENT_RE to prevent SQL injection - Replace assert bs is not None with RuntimeError for -O safety - Add hash=False to NearFilter.pose, TagsFilter.tags, PredicateFilter.fn - Collapse CaptionDetail enum to 3 distinct levels (BRIEF/NORMAL/DETAILED) - Fix Stream.map() return type: Stream[Any] → Stream[R] - Update architecture.md: SqliteBackend status Stub → Complete - Document SqliteBlobStore commit responsibility - Guard ImageDetections.ts against image=None * Revert detection type changes: keep image as required field Restores detection2d/bbox.py, imageDetections.py, and utils.py to dev state — the image-optional decoupling is not needed for memory2. * add libturbojpeg to docker image * Make turbojpeg import lazy so tests skip gracefully in CI Move top-level turbojpeg import in Image.py to the two methods that use it, and guard jpeg codec tests behind ImportError / importorskip so the test suite passes when libturbojpeg is not installed. * Give each SqliteBackend its own connection for WAL-mode concurrency Previously all backends shared a single sqlite3.Connection — concurrent writes from different streams could interleave commits/rollbacks. Now SqliteSession opens a dedicated connection per backend, with per-backend blob/vector stores wrapping the same connection for atomicity. A separate registry connection handles the _streams table. Also makes SqliteBackend a CompositeResource so session.own(backend) properly closes connections on stop, and fixes live iterator cleanup in both backends (backfill phase now inside try/finally). * Block search_text on SqliteBackend to prevent full table scans search_text previously loaded every blob from the DB and did Python substring matching — a silent full table scan. Raise NotImplementedError instead until proper SQL pushdown is implemented. * Catch RuntimeError from missing turbojpeg native library in codec tests TurboJPEG import succeeds but instantiation raises RuntimeError when the native library isn't installed. Skip the test case gracefully. * pr comments * occupancy change undo * tests cleanup * compression codec added, new bigoffice db uploaded * correct jpeg codec * PR comments cleanup * blobstore stream -> stream_name * vectorstore stream -> stream_name * resource typing fixes * move type definitions into dimos/memory2/type/ subpackage Separate pure-definition files (protocols, ABCs, dataclasses) from implementation files by moving them into a type/ subpackage: - backend.py → type/backend.py - type.py → type/observation.py - filter.py → type/filter.py Added type/__init__.py with re-exports for convenience imports. Updated all 24 importing files across the module. * lz4 codec included, utils/ cleanup * migrated stores to a new config system * config fix * rewrite * update memory2 docs to reflect new architecture - Remove Session layer references (Store → Stream directly) - Backend → Index protocol, concrete Backend composite - SessionConfig/BackendConfig → StoreConfig - ListBackend/SqliteBackend → ListIndex/SqliteIndex - Updated impl README with new 'writing a new index' guide - Verified intro.md code blocks via md-babel-py * rename LiveChannel → Notifier, SubjectChannel → SubjectNotifier Clearer name for the push-notification ABC — "Notifier" directly conveys its subscribe/notify role without leaking the "live" stream concept into a lower layer. * rename Index → MetadataStore, drop Backend property boilerplate, simplify Store.stream() - Index → MetadataStore, ListIndex → ListMetadataStore, SqliteIndex → SqliteMetadataStore Consistent naming with BlobStore/VectorStore. Backend composition reads: MetadataStore + BlobStore + VectorStore + Notifier - Backend: replace _private + @property accessors with plain public attributes - Store.stream(): use model_dump(exclude_none=True) instead of manual dict filtering * rename MetadataStore → ObservationStore Better name — describes what it stores, not the kind of data. Parallels BlobStore/VectorStore naturally. * self-contained SQLite components with dual-mode constructors (conn/path) Move table DDL into SqliteObservationStore.__init__ so all three SQLite components (ObservationStore, BlobStore, VectorStore) are self-contained and can be used standalone with path= without needing a full Store. - Extract open_sqlite_connection utility from SqliteStore._open_connection - Add path= keyword to SqliteBlobStore, SqliteVectorStore, SqliteObservationStore - Promote BlobStore/VectorStore base classes to CompositeResource for clean connection ownership via register_disposables - SqliteStore now closes backend_conn directly instead of via metadata_store.stop() - Add standalone component tests verifying path= mode works without Store * move ObservationStore classes into observationstore/ directory Matches the existing pattern of blobstore/ and vectorstore/ having their own directories. SqliteObservationStore + helpers moved from impl/sqlite.py, ListObservationStore moved from impl/memory.py. impl/ files now import from the new location. * add RegistryStore to persist fully-resolved backend config per stream The old _streams table only stored (name, payload_module, codec_id), so stream overrides (blob_store, vector_store, eager_blobs, page_size, etc.) were lost on reopen. RegistryStore stores the complete serialized config as JSON, enabling _create_backend to reconstruct any stream identically. Each component (SqliteBlobStore, FileBlobStore, SqliteVectorStore, SqliteObservationStore, SubjectNotifier) gets a pydantic Config class and serialize/deserialize methods. Backend.serialize() orchestrates the sub-stores. SqliteStore splits _create_backend into a create path (live objects) and a load path (deserialized config). Includes automatic migration from the legacy three-column schema. * move ABCs from type/backend.py into their own dirs, rename livechannel → notifier Each abstract base class now lives as base.py in its implementation directory: blobstore/base.py, vectorstore/base.py, observationstore/base.py, notifier/base.py. type/backend.py is deleted. livechannel/ is renamed to notifier/ with a backwards-compat shim so old serialized registry entries still resolve via importlib. * move serialize() to base classes, drop deserialize() in favor of constructor serialize() is now a concrete method on BlobStore, VectorStore, and Notifier base classes — implementations inherit it via self._config.model_dump(). deserialize() classmethods are removed entirely; deserialize_component() in registry.py calls cls(**config) directly. Backend.deserialize() is also removed (unused — _assemble_backend handles reconstruction). * move _create_backend to Store base, MemoryStore becomes empty subclass Store._create_backend is now concrete — resolves codec, instantiates components (class → instance or uses instance directly), builds Backend. StoreConfig holds typed component fields (class or instance) with in-memory defaults. codec removed from StoreConfig (per-stream concern, not store-level). MemoryStore is now just `pass` — inherits everything from Store. SqliteStore overrides _create_backend to inject conn-shared components and registry persistence, then delegates to super(). * move connection init from __init__ to start(), make ObservationStore a Resource SQLite components (BlobStore, VectorStore, ObservationStore) now defer connection opening and table creation to start(). __init__ stores config only. Store._create_backend and SqliteStore._create_backend call start() on all components they instantiate. ObservationStore converted from Protocol to CompositeResource base class so all observation stores inherit start()/stop() lifecycle. * rename impl/ → store/, move store.py → store/base.py All store-related code now lives under store/: base class in base.py, MemoryStore in memory.py, SqliteStore in sqlite.py. store/__init__.py re-exports public API. Also renamed test_impl.py → test_store.py. * remove section separator comments from memory2/ * remove __init__.py re-exports, use direct module imports Subdirectory __init__.py files in memory2/ were re-exporting symbols from their submodules. Replace all imports with direct module paths (e.g. utils.sqlite.open_sqlite_connection instead of utils) and empty out the __init__.py files. * delete livechannel/ backwards-compat shim * simplify RegistryStore: drop legacy schema migration Replace _migrate_or_create with CREATE TABLE IF NOT EXISTS. * use context managers in standalone component tests Replace start()/try/finally/stop() with `with` statements. * delete all __init__.py files from memory2/ No code imports from package-level; all use direct module paths. Python 3.3+ implicit namespace packages make these unnecessary. * make all memory2 sub-store components Configurable Migrate BlobStore, VectorStore, ObservationStore, Notifier, and RegistryStore to use the Configurable[ConfigT] mixin pattern, matching the existing Store class. Runtime deps (conn, codec) use Field(exclude=True) so serialize()/model_dump() skips them. All call sites updated to keyword args. * add open_disposable_sqlite_connection and use it everywhere Centralizes the pattern of opening a SQLite connection paired with a disposable that closes it, replacing manual Disposable(lambda: conn.close()) at each call site. * add StreamAccessor for attribute-style stream access on Store * small cleanups: BlobStore.delete raises KeyError on missing, drop _MISSING sentinel * checkout mapping/occupancy/gradient.py from dev * limit opencv threads to 2 by default, checkout worker.py from dev * test for magic accessor * ci/pr comments * widen flaky pointcloud AABB tolerance from 0.1 to 0.2 The test_detection3dpc test fails intermittently in full suite runs due to non-deterministic point cloud boundary values. * suppress mypy false positive on scipy distance_transform_edt return type * ci test fixes * sam mini PR comments * replace Generator[T, None, None] with Iterator[T] in memory2 tests * fix missing TypeVar import in subject.py * skipping turbojpeg stuff in CI * removed db from lfs for now * turbojpeg --- dimos/core/library_config.py | 27 + dimos/core/resource.py | 45 +- dimos/core/worker.py | 2 + dimos/mapping/occupancy/gradient.py | 2 +- dimos/memory2/architecture.md | 114 +++ dimos/memory2/backend.py | 244 ++++++ dimos/memory2/blobstore/base.py | 58 ++ dimos/memory2/blobstore/blobstore.md | 84 ++ dimos/memory2/blobstore/file.py | 70 ++ dimos/memory2/blobstore/sqlite.py | 108 +++ dimos/memory2/blobstore/test_blobstore.py | 62 ++ dimos/memory2/buffer.py | 248 ++++++ dimos/memory2/codecs/README.md | 57 ++ dimos/memory2/codecs/base.py | 112 +++ dimos/memory2/codecs/jpeg.py | 39 + dimos/memory2/codecs/lcm.py | 33 + dimos/memory2/codecs/lz4.py | 42 + dimos/memory2/codecs/pickle.py | 28 + dimos/memory2/codecs/test_codecs.py | 185 +++++ dimos/memory2/conftest.py | 89 +++ dimos/memory2/embed.py | 79 ++ dimos/memory2/embeddings.md | 148 ++++ dimos/memory2/intro.md | 170 ++++ dimos/memory2/notes.md | 10 + dimos/memory2/notifier/base.py | 62 ++ dimos/memory2/notifier/subject.py | 70 ++ dimos/memory2/observationstore/base.py | 73 ++ dimos/memory2/observationstore/memory.py | 80 ++ dimos/memory2/observationstore/sqlite.py | 444 +++++++++++ dimos/memory2/registry.py | 81 ++ dimos/memory2/store/README.md | 130 ++++ dimos/memory2/store/base.py | 166 ++++ dimos/memory2/store/memory.py | 21 + dimos/memory2/store/sqlite.py | 217 ++++++ dimos/memory2/stream.py | 363 +++++++++ dimos/memory2/streaming.md | 109 +++ dimos/memory2/test_blobstore_integration.py | 161 ++++ dimos/memory2/test_buffer.py | 86 +++ dimos/memory2/test_e2e.py | 256 ++++++ dimos/memory2/test_e2e_processing.py | 16 + dimos/memory2/test_embedding.py | 396 ++++++++++ dimos/memory2/test_registry.py | 263 +++++++ dimos/memory2/test_save.py | 123 +++ dimos/memory2/test_store.py | 527 +++++++++++++ dimos/memory2/test_stream.py | 728 ++++++++++++++++++ dimos/memory2/transform.py | 115 +++ dimos/memory2/type/filter.py | 212 +++++ dimos/memory2/type/observation.py | 112 +++ dimos/memory2/utils/formatting.py | 58 ++ dimos/memory2/utils/sqlite.py | 43 ++ dimos/memory2/utils/validation.py | 25 + dimos/memory2/vectorstore/base.py | 65 ++ dimos/memory2/vectorstore/memory.py | 61 ++ dimos/memory2/vectorstore/sqlite.py | 103 +++ dimos/models/embedding/clip.py | 6 +- dimos/models/vl/florence.py | 50 +- dimos/msgs/sensor_msgs/Image.py | 27 +- .../type/detection3d/test_pointcloud.py | 24 +- dimos/utils/docs/doclinks.py | 23 +- dimos/utils/threadpool.py | 2 +- docker/python/Dockerfile | 3 +- docs/usage/transports/index.md | 2 +- pyproject.toml | 5 +- uv.lock | 135 ++++ 64 files changed, 7436 insertions(+), 63 deletions(-) create mode 100644 dimos/core/library_config.py create mode 100644 dimos/memory2/architecture.md create mode 100644 dimos/memory2/backend.py create mode 100644 dimos/memory2/blobstore/base.py create mode 100644 dimos/memory2/blobstore/blobstore.md create mode 100644 dimos/memory2/blobstore/file.py create mode 100644 dimos/memory2/blobstore/sqlite.py create mode 100644 dimos/memory2/blobstore/test_blobstore.py create mode 100644 dimos/memory2/buffer.py create mode 100644 dimos/memory2/codecs/README.md create mode 100644 dimos/memory2/codecs/base.py create mode 100644 dimos/memory2/codecs/jpeg.py create mode 100644 dimos/memory2/codecs/lcm.py create mode 100644 dimos/memory2/codecs/lz4.py create mode 100644 dimos/memory2/codecs/pickle.py create mode 100644 dimos/memory2/codecs/test_codecs.py create mode 100644 dimos/memory2/conftest.py create mode 100644 dimos/memory2/embed.py create mode 100644 dimos/memory2/embeddings.md create mode 100644 dimos/memory2/intro.md create mode 100644 dimos/memory2/notes.md create mode 100644 dimos/memory2/notifier/base.py create mode 100644 dimos/memory2/notifier/subject.py create mode 100644 dimos/memory2/observationstore/base.py create mode 100644 dimos/memory2/observationstore/memory.py create mode 100644 dimos/memory2/observationstore/sqlite.py create mode 100644 dimos/memory2/registry.py create mode 100644 dimos/memory2/store/README.md create mode 100644 dimos/memory2/store/base.py create mode 100644 dimos/memory2/store/memory.py create mode 100644 dimos/memory2/store/sqlite.py create mode 100644 dimos/memory2/stream.py create mode 100644 dimos/memory2/streaming.md create mode 100644 dimos/memory2/test_blobstore_integration.py create mode 100644 dimos/memory2/test_buffer.py create mode 100644 dimos/memory2/test_e2e.py create mode 100644 dimos/memory2/test_e2e_processing.py create mode 100644 dimos/memory2/test_embedding.py create mode 100644 dimos/memory2/test_registry.py create mode 100644 dimos/memory2/test_save.py create mode 100644 dimos/memory2/test_store.py create mode 100644 dimos/memory2/test_stream.py create mode 100644 dimos/memory2/transform.py create mode 100644 dimos/memory2/type/filter.py create mode 100644 dimos/memory2/type/observation.py create mode 100644 dimos/memory2/utils/formatting.py create mode 100644 dimos/memory2/utils/sqlite.py create mode 100644 dimos/memory2/utils/validation.py create mode 100644 dimos/memory2/vectorstore/base.py create mode 100644 dimos/memory2/vectorstore/memory.py create mode 100644 dimos/memory2/vectorstore/sqlite.py diff --git a/dimos/core/library_config.py b/dimos/core/library_config.py new file mode 100644 index 0000000000..813fb642f6 --- /dev/null +++ b/dimos/core/library_config.py @@ -0,0 +1,27 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Process-wide library defaults. +# Modules that need different settings can override in their own start(). + + +def apply_library_config() -> None: + """Apply process-wide library defaults. Call once per process.""" + # Limit OpenCV internal threads to avoid idle thread contention. + try: + import cv2 + + cv2.setNumThreads(2) + except ImportError: + pass diff --git a/dimos/core/resource.py b/dimos/core/resource.py index ce3f735329..63b1eec4f0 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -12,10 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod +from __future__ import annotations +from abc import abstractmethod +from typing import TYPE_CHECKING, Self -class Resource(ABC): +if TYPE_CHECKING: + from types import TracebackType + +from reactivex.abc import DisposableBase +from reactivex.disposable import CompositeDisposable + + +class Resource(DisposableBase): @abstractmethod def start(self) -> None: ... @@ -43,3 +52,35 @@ def dispose(self) -> None: """ self.stop() + + def __enter__(self) -> Self: + self.start() + return self + + def __exit__( + self, + exctype: type[BaseException] | None, + excinst: BaseException | None, + exctb: TracebackType | None, + ) -> None: + self.stop() + + +class CompositeResource(Resource): + """Resource that owns child disposables, disposed on stop().""" + + _disposables: CompositeDisposable + + def __init__(self) -> None: + self._disposables = CompositeDisposable() + + def register_disposables(self, *disposables: DisposableBase) -> None: + """Register child disposables to be disposed when this resource stops.""" + for d in disposables: + self._disposables.add(d) + + def start(self) -> None: + pass + + def stop(self) -> None: + self._disposables.dispose() diff --git a/dimos/core/worker.py b/dimos/core/worker.py index dca561f16c..8f3beee7ec 100644 --- a/dimos/core/worker.py +++ b/dimos/core/worker.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.library_config import apply_library_config from dimos.utils.logging_config import setup_logger from dimos.utils.sequential_ids import SequentialIds @@ -292,6 +293,7 @@ def _suppress_console_output() -> None: def _worker_entrypoint(conn: Connection, worker_id: int) -> None: + apply_library_config() instances: dict[int, Any] = {} try: diff --git a/dimos/mapping/occupancy/gradient.py b/dimos/mapping/occupancy/gradient.py index 880f2692da..c9db43088e 100644 --- a/dimos/mapping/occupancy/gradient.py +++ b/dimos/mapping/occupancy/gradient.py @@ -53,7 +53,7 @@ def gradient( distance_cells = ndimage.distance_transform_edt(1 - obstacle_map) # Convert to meters and clip to max distance - distance_meters = np.clip(distance_cells * occupancy_grid.resolution, 0, max_distance) + distance_meters = np.clip(distance_cells * occupancy_grid.resolution, 0, max_distance) # type: ignore[operator] # Invert and scale to 0-100 range # Far from obstacles (max_distance) -> 0 diff --git a/dimos/memory2/architecture.md b/dimos/memory2/architecture.md new file mode 100644 index 0000000000..9dc805577f --- /dev/null +++ b/dimos/memory2/architecture.md @@ -0,0 +1,114 @@ +# memory + +Observation storage and streaming layer for DimOS. Pull-based, lazy, composable. + +## Architecture + +``` + Live Sensor Data + ↓ +Store → Stream → [filters / transforms / terminals] → Stream → [filters / transforms / terminals] → Stream → Live hooks + ↓ ↓ ↓ + Backend (ObservationStore + BlobStore + VectorStore + Notifier) Backend In Memory +``` + +**Store** owns a storage location (file, in-memory) and directly manages named streams. **Stream** is the query/iteration surface — lazy until a terminal is called. **Backend** is a concrete composite that orchestrates ObservationStore + BlobStore + VectorStore + Notifier for each stream. + +Supporting Systems: + +- BlobStore — separates large payloads from metadata. FileBlobStore (files on disk) and SqliteBlobStore (blob table per stream). Supports lazy loading. +- Codecs — codec_for() auto-selects: JpegCodec for images (TurboJPEG, ~10-20x compression), LcmCodec for DimOS messages, PickleCodec fallback. +- Transformers — Transformer[T,R] ABC wrapping iterator-to-iterator. EmbedImages/EmbedText enrich observations with embeddings. QualityWindow keeps best per time window. +- Backpressure Buffers — KeepLast, Bounded, DropNew, Unbounded — bridge push/pull for live mode. + + +## Modules + +| Module | What | +|----------------|-------------------------------------------------------------------| +| `stream.py` | Stream node — filters, transforms, terminals | +| `backend.py` | Concrete Backend composite (ObservationStore + Blob + Vector + Live) | +| `store.py` | Store, StoreConfig | +| `transform.py` | Transformer ABC, FnTransformer, FnIterTransformer, QualityWindow | +| `buffer.py` | Backpressure buffers for live mode (KeepLast, Bounded, Unbounded) | +| `embed.py` | EmbedImages / EmbedText transformers | + +## Subpackages + +| Package | What | Docs | +|-----------------|------------------------------------------------------|--------------------------------------------------| +| `type/` | Observation, EmbeddedObservation, Filter/StreamQuery | | +| `store/` | Store ABC + implementations (MemoryStore, SqliteStore) | [store/README.md](store/README.md) | +| `notifier/` | Notifier ABC + SubjectNotifier | | +| `blobstore/` | BlobStore ABC + implementations (file, sqlite) | [blobstore/blobstore.md](blobstore/blobstore.md) | +| `codecs/` | Encode/decode for storage (pickle, JPEG, LCM) | [codecs/README.md](codecs/README.md) | +| `vectorstore/` | VectorStore ABC + implementations (memory, sqlite) | | +| `observationstore/` | ObservationStore Protocol + implementations | | + +## Docs + +| Doc | What | +|-----|------| +| [streaming.md](streaming.md) | Lazy vs materializing vs terminal — evaluation model, live safety | +| [embeddings.md](embeddings.md) | Embedding layer design — EmbeddedObservation, vector search, EmbedImages/EmbedText | +| [blobstore/blobstore.md](blobstore/blobstore.md) | BlobStore architecture — separate payload storage from metadata | + +## Query execution + +`StreamQuery` holds the full query spec (filters, text search, vector search, ordering, offset/limit). It also provides `apply(iterator)` — a Python-side execution path that runs all operations as in-memory predicates, brute-force cosine, and list sorts. + +This is the **default fallback**. ObservationStore implementations are free to push down operations using store-specific strategies instead: + +| Operation | Python fallback (`StreamQuery.apply`) | Store push-down (example) | +|----------------|---------------------------------------|----------------------------------| +| Filters | `filter.matches()` predicates | SQL WHERE clauses | +| Text search | Case-insensitive substring | FTS5 full-text index | +| Vector search | Brute-force cosine similarity | vec0 / FAISS ANN index | +| Ordering | `sorted()` materialization | SQL ORDER BY | +| Offset / limit | `islice()` | SQL OFFSET / LIMIT | + +`ListObservationStore` delegates entirely to `StreamQuery.apply()`. `SqliteObservationStore` translates the query into SQL and only falls back to Python for operations it can't express natively. + +Transform-sourced streams (post `.transform()`) always use `StreamQuery.apply()` since there's no index to push down to. + +## Quick start + +```python +from dimos.memory2 import MemoryStore + +store = MemoryStore() +images = store.stream("images") + +# Write +images.append(frame, ts=time.time(), pose=(x, y, z), tags={"camera": "front"}) + +# Query +recent = images.after(t).limit(10).fetch() +nearest = images.near(pose, radius=2.0).fetch() +latest = images.last() + +# Transform (class or bare generator function) +edges = images.transform(Canny()).save(store.stream("edges")) + +def running_avg(upstream): + total, n = 0.0, 0 + for obs in upstream: + total += obs.data; n += 1 + yield obs.derive(data=total / n) +avgs = stream.transform(running_avg).fetch() + +# Live +for obs in images.live().transform(process): + handle(obs) + +# Embed + search +images.transform(EmbedImages(clip)).save(store.stream("embedded")) +results = store.stream("embedded").search(query_vec, k=5).fetch() +``` + +## Implementations + +| ObservationStore | Status | Storage | +|-----------------|----------|----------------------------------------| +| `ListObservationStore` | Complete | In-memory (lists + brute-force search) | +| `SqliteObservationStore` | Complete | SQLite (WAL, FTS5, vec0) | diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py new file mode 100644 index 0000000000..c861993de9 --- /dev/null +++ b/dimos/memory2/backend.py @@ -0,0 +1,244 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Concrete composite Backend that orchestrates ObservationStore + BlobStore + VectorStore + Notifier.""" + +from __future__ import annotations + +from dataclasses import replace +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.memory2.codecs.base import Codec, codec_id +from dimos.memory2.notifier.subject import SubjectNotifier +from dimos.memory2.type.observation import _UNLOADED + +if TYPE_CHECKING: + from collections.abc import Iterator + + from reactivex.abc import DisposableBase + + from dimos.memory2.blobstore.base import BlobStore + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.notifier.base import Notifier + from dimos.memory2.observationstore.base import ObservationStore + from dimos.memory2.type.filter import StreamQuery + from dimos.memory2.type.observation import Observation + from dimos.memory2.vectorstore.base import VectorStore + +T = TypeVar("T") + + +class Backend(Generic[T]): + """Orchestrates metadata, blob, vector, and live stores for one stream. + + This is a concrete class — NOT a protocol. All shared orchestration logic + (encode → insert → store blob → index vector → notify) lives here, + eliminating duplication between ListObservationStore and SqliteObservationStore. + """ + + def __init__( + self, + *, + metadata_store: ObservationStore[T], + codec: Codec[Any], + blob_store: BlobStore | None = None, + vector_store: VectorStore | None = None, + notifier: Notifier[T] | None = None, + eager_blobs: bool = False, + ) -> None: + self.metadata_store = metadata_store + self.codec = codec + self.blob_store = blob_store + self.vector_store = vector_store + self.notifier: Notifier[T] = notifier or SubjectNotifier() + self.eager_blobs = eager_blobs + + @property + def name(self) -> str: + return self.metadata_store.name + + def _make_loader(self, row_id: int) -> Any: + bs = self.blob_store + if bs is None: + raise RuntimeError("BlobStore required but not configured") + name, codec = self.name, self.codec + + def loader() -> Any: + raw = bs.get(name, row_id) + return codec.decode(raw) + + return loader + + def append(self, obs: Observation[T]) -> Observation[T]: + # Encode payload before any locking (avoids holding locks during IO) + encoded: bytes | None = None + if self.blob_store is not None: + encoded = self.codec.encode(obs._data) + + try: + # Insert metadata, get assigned id + row_id = self.metadata_store.insert(obs) + obs.id = row_id + + # Store blob + if encoded is not None: + assert self.blob_store is not None + self.blob_store.put(self.name, row_id, encoded) + # Replace inline data with lazy loader + obs._data = _UNLOADED # type: ignore[assignment] + obs._loader = self._make_loader(row_id) + + # Store embedding vector + if self.vector_store is not None: + emb = getattr(obs, "embedding", None) + if emb is not None: + self.vector_store.put(self.name, row_id, emb) + + # Commit if the metadata store supports it (e.g. SqliteObservationStore) + if hasattr(self.metadata_store, "commit"): + self.metadata_store.commit() + except BaseException: + if hasattr(self.metadata_store, "rollback"): + self.metadata_store.rollback() + raise + + self.notifier.notify(obs) + return obs + + def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: + if query.search_vec is not None and query.live_buffer is not None: + raise TypeError("Cannot combine .search() with .live() — search is a batch operation.") + buf = query.live_buffer + if buf is not None: + sub = self.notifier.subscribe(buf) + return self._iterate_live(query, buf, sub) + return self._iterate_snapshot(query) + + def _attach_loaders(self, it: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + """Attach lazy blob loaders to observations from the metadata store.""" + if self.blob_store is None: + yield from it + return + for obs in it: + if obs._loader is None and isinstance(obs._data, type(_UNLOADED)): + obs._loader = self._make_loader(obs.id) + yield obs + + def _iterate_snapshot(self, query: StreamQuery) -> Iterator[Observation[T]]: + if query.search_vec is not None and self.vector_store is not None: + yield from self._vector_search(query) + return + + it: Iterator[Observation[T]] = self._attach_loaders(self.metadata_store.query(query)) + + # Apply python post-filters after loaders are attached (so obs.data works) + python_filters = getattr(self.metadata_store, "_pending_python_filters", None) + pending_query = getattr(self.metadata_store, "_pending_query", None) + if python_filters: + from itertools import islice as _islice + + it = (obs for obs in it if all(f.matches(obs) for f in python_filters)) + if pending_query and pending_query.offset_val: + it = _islice(it, pending_query.offset_val, None) + if pending_query and pending_query.limit_val is not None: + it = _islice(it, pending_query.limit_val) + + if self.eager_blobs and self.blob_store is not None: + for obs in it: + _ = obs.data # trigger lazy loader + yield obs + else: + yield from it + + def _vector_search(self, query: StreamQuery) -> Iterator[Observation[T]]: + vs = self.vector_store + assert vs is not None and query.search_vec is not None + + hits = vs.search(self.name, query.search_vec, query.search_k or 10) + if not hits: + return + + ids = [h[0] for h in hits] + obs_list = list(self._attach_loaders(iter(self.metadata_store.fetch_by_ids(ids)))) + obs_by_id = {obs.id: obs for obs in obs_list} + + # Preserve VectorStore ranking order + ranked: list[Observation[T]] = [] + for obs_id, sim in hits: + match = obs_by_id.get(obs_id) + if match is not None: + ranked.append( + match.derive(data=match.data, embedding=query.search_vec, similarity=sim) + ) + + # Apply remaining query ops (skip vector search) + rest = replace(query, search_vec=None, search_k=None) + yield from rest.apply(iter(ranked)) + + def _iterate_live( + self, + query: StreamQuery, + buf: BackpressureBuffer[Observation[T]], + sub: DisposableBase, + ) -> Iterator[Observation[T]]: + from dimos.memory2.buffer import ClosedError + + eager = self.eager_blobs and self.blob_store is not None + + try: + # Backfill phase + last_id = -1 + for obs in self._iterate_snapshot(query): + last_id = max(last_id, obs.id) + yield obs + + # Live tail + filters = query.filters + while True: + obs = buf.take() + if obs.id <= last_id: + continue + last_id = obs.id + if filters and not all(f.matches(obs) for f in filters): + continue + if eager: + _ = obs.data # trigger lazy loader + yield obs + except (ClosedError, StopIteration): + pass + finally: + sub.dispose() + + def count(self, query: StreamQuery) -> int: + if query.search_vec: + return sum(1 for _ in self.iterate(query)) + return self.metadata_store.count(query) + + def serialize(self) -> dict[str, Any]: + """Serialize the fully-resolved backend config to a dict.""" + return { + "codec_id": codec_id(self.codec), + "eager_blobs": self.eager_blobs, + "metadata_store": self.metadata_store.serialize() + if hasattr(self.metadata_store, "serialize") + else None, + "blob_store": self.blob_store.serialize() if self.blob_store else None, + "vector_store": self.vector_store.serialize() if self.vector_store else None, + "notifier": self.notifier.serialize(), + } + + def stop(self) -> None: + """Stop the metadata store (closes per-stream connections if any).""" + if hasattr(self.metadata_store, "stop"): + self.metadata_store.stop() diff --git a/dimos/memory2/blobstore/base.py b/dimos/memory2/blobstore/base.py new file mode 100644 index 0000000000..b146d2028e --- /dev/null +++ b/dimos/memory2/blobstore/base.py @@ -0,0 +1,58 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any + +from dimos.core.resource import CompositeResource +from dimos.memory2.registry import qual +from dimos.protocol.service.spec import BaseConfig, Configurable + + +class BlobStoreConfig(BaseConfig): + pass + + +class BlobStore(Configurable[BlobStoreConfig], CompositeResource): + """Persistent storage for encoded payload blobs. + + Separates payload data from metadata indexing so that large blobs + (images, point clouds) don't penalize metadata queries. + """ + + default_config: type[BlobStoreConfig] = BlobStoreConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) + + @abstractmethod + def put(self, stream_name: str, key: int, data: bytes) -> None: + """Store a blob for the given stream and observation id.""" + ... + + @abstractmethod + def get(self, stream_name: str, key: int) -> bytes: + """Retrieve a blob by stream name and observation id.""" + ... + + @abstractmethod + def delete(self, stream_name: str, key: int) -> None: + """Delete a blob by stream name and observation id.""" + ... + + def serialize(self) -> dict[str, Any]: + return {"class": qual(type(self)), "config": self.config.model_dump()} diff --git a/dimos/memory2/blobstore/blobstore.md b/dimos/memory2/blobstore/blobstore.md new file mode 100644 index 0000000000..00006cf468 --- /dev/null +++ b/dimos/memory2/blobstore/blobstore.md @@ -0,0 +1,84 @@ +# blobstore/ + +Separates payload blob storage from metadata indexing. Observation payloads vary hugely in size — a `Vector3` is 24 bytes, a camera frame is megabytes. Storing everything inline penalizes metadata queries. BlobStore lets large payloads live elsewhere. + +## ABC (`blobstore/base.py`) + +```python +class BlobStore(Resource): + def put(self, stream_name: str, key: int, data: bytes) -> None: ... + def get(self, stream_name: str, key: int) -> bytes: ... # raises KeyError if missing + def delete(self, stream_name: str, key: int) -> None: ... # silent if missing +``` + +- `stream_name` — stream name (used to organize storage: directories, tables) +- `key` — observation id +- `data` — encoded payload bytes (codec handles serialization, blob store handles persistence) +- Extends `Resource` (start/stop) but does NOT own its dependencies' lifecycle + +## Implementations + +### `file.py` — FileBlobStore + +Stores blobs as files on disk, one directory per stream. + +``` +{root}/{stream}/{key}.bin +``` + +`__init__(root: str | os.PathLike[str])` — `start()` creates the root directory. + +### `sqlite.py` — SqliteBlobStore + +Stores blobs in a separate SQLite table per stream. + +```sql +CREATE TABLE "{stream}_blob" (id INTEGER PRIMARY KEY, data BLOB NOT NULL) +``` + +`__init__(conn: sqlite3.Connection)` — does NOT own the connection. + +**Internal use** (same db as metadata): `SqliteStore._create_backend()` creates one connection per stream, passes it to both the index and the blob store. + +**External use** (separate db): user creates a separate connection and passes it. User manages that connection's lifecycle. + +**JOIN optimization**: when `eager_blobs=True` and the blob store shares the same connection as the index, `SqliteObservationStore` can optimize with a JOIN instead of separate queries: + +```sql +SELECT m.id, m.ts, m.pose, m.tags, b.data +FROM "images" m JOIN "images_blob" b ON m.id = b.id +WHERE m.ts > ? +``` + +## Lazy loading + +`eager_blobs` is a store/stream-level flag, orthogonal to blob store choice. It controls WHEN data is loaded: + +- `eager_blobs=False` (default) → backend sets `Observation._loader`, payload loaded on `.data` access +- `eager_blobs=True` → backend triggers `.data` access during iteration (eager) + +| eager_blobs | blob store | loading strategy | +|-------------|-----------|-----------------| +| True | SqliteBlobStore (same conn) | JOIN — one round trip | +| True | any other | iterate meta, `blob_store.get()` per row | +| False | any | iterate meta only, `_loader = lambda: codec.decode(blob_store.get(...))` | + +## Usage + +```python +# Per-stream blob store choice +poses = store.stream("poses", PoseStamped) # default, lazy +images = store.stream("images", Image, eager_blobs=True) # eager +images = store.stream("images", Image, blob_store=file_blobs) # override +``` + +## Files + +``` +blobstore/ + base.py BlobStore ABC + blobstore.md this file + __init__.py re-exports BlobStore, FileBlobStore, SqliteBlobStore + file.py FileBlobStore + sqlite.py SqliteBlobStore +``` diff --git a/dimos/memory2/blobstore/file.py b/dimos/memory2/blobstore/file.py new file mode 100644 index 0000000000..e0ae80b61a --- /dev/null +++ b/dimos/memory2/blobstore/file.py @@ -0,0 +1,70 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from dimos.memory2.blobstore.base import BlobStore, BlobStoreConfig +from dimos.memory2.utils.validation import validate_identifier + + +class FileBlobStoreConfig(BlobStoreConfig): + root: str + + +class FileBlobStore(BlobStore): + """Stores blobs as files on disk, one directory per stream. + + Layout:: + + {root}/{stream}/{key}.bin + """ + + default_config = FileBlobStoreConfig + config: FileBlobStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._root = Path(self.config.root) + + def _path(self, stream_name: str, key: int) -> Path: + validate_identifier(stream_name) + return self._root / stream_name / f"{key}.bin" + + def start(self) -> None: + self._root.mkdir(parents=True, exist_ok=True) + + def stop(self) -> None: + pass + + def put(self, stream_name: str, key: int, data: bytes) -> None: + p = self._path(stream_name, key) + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(data) + + def get(self, stream_name: str, key: int) -> bytes: + p = self._path(stream_name, key) + try: + return p.read_bytes() + except FileNotFoundError: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") from None + + def delete(self, stream_name: str, key: int) -> None: + p = self._path(stream_name, key) + try: + p.unlink() + except FileNotFoundError: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") from None diff --git a/dimos/memory2/blobstore/sqlite.py b/dimos/memory2/blobstore/sqlite.py new file mode 100644 index 0000000000..1cb5f1aa38 --- /dev/null +++ b/dimos/memory2/blobstore/sqlite.py @@ -0,0 +1,108 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sqlite3 +from typing import Any + +from pydantic import Field, model_validator + +from dimos.memory2.blobstore.base import BlobStore, BlobStoreConfig +from dimos.memory2.utils.sqlite import open_disposable_sqlite_connection +from dimos.memory2.utils.validation import validate_identifier + + +class SqliteBlobStoreConfig(BlobStoreConfig): + conn: sqlite3.Connection | None = Field(default=None, exclude=True) + path: str | None = None + + @model_validator(mode="after") + def _conn_xor_path(self) -> SqliteBlobStoreConfig: + if self.conn is not None and self.path is not None: + raise ValueError("Specify either conn or path, not both") + if self.conn is None and self.path is None: + raise ValueError("Specify either conn or path") + return self + + +class SqliteBlobStore(BlobStore): + """Stores blobs in a separate SQLite table per stream. + + Table layout per stream:: + + CREATE TABLE "{stream}_blob" ( + id INTEGER PRIMARY KEY, + data BLOB NOT NULL + ); + + Supports two construction modes: + + - ``SqliteBlobStore(conn=conn)`` — borrows an externally-managed connection. + - ``SqliteBlobStore(path="file.db")`` — opens and owns its own connection. + + Does NOT commit; the caller (typically Backend) is responsible for commits. + """ + + default_config = SqliteBlobStoreConfig + config: SqliteBlobStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._conn: sqlite3.Connection = self.config.conn # type: ignore[assignment] # set in start() if None + self._path = self.config.path + self._tables: set[str] = set() + + def _ensure_table(self, stream_name: str) -> None: + if stream_name in self._tables: + return + validate_identifier(stream_name) + self._conn.execute( + f'CREATE TABLE IF NOT EXISTS "{stream_name}_blob" ' + "(id INTEGER PRIMARY KEY, data BLOB NOT NULL)" + ) + self._tables.add(stream_name) + + def start(self) -> None: + if self._conn is None: + assert self._path is not None + disposable, self._conn = open_disposable_sqlite_connection(self._path) + self.register_disposables(disposable) + + def put(self, stream_name: str, key: int, data: bytes) -> None: + self._ensure_table(stream_name) + self._conn.execute( + f'INSERT OR REPLACE INTO "{stream_name}_blob" (id, data) VALUES (?, ?)', + (key, data), + ) + + def get(self, stream_name: str, key: int) -> bytes: + try: + row = self._conn.execute( + f'SELECT data FROM "{stream_name}_blob" WHERE id = ?', (key,) + ).fetchone() + except Exception: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") + if row is None: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") + result: bytes = row[0] + return result + + def delete(self, stream_name: str, key: int) -> None: + try: + cur = self._conn.execute(f'DELETE FROM "{stream_name}_blob" WHERE id = ?', (key,)) + except Exception: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") from None + if cur.rowcount == 0: + raise KeyError(f"No blob for stream={stream_name!r}, key={key}") diff --git a/dimos/memory2/blobstore/test_blobstore.py b/dimos/memory2/blobstore/test_blobstore.py new file mode 100644 index 0000000000..ade6aa4cc6 --- /dev/null +++ b/dimos/memory2/blobstore/test_blobstore.py @@ -0,0 +1,62 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grid tests for BlobStore implementations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from dimos.memory2.blobstore.base import BlobStore + + +class TestBlobStore: + def test_put_get_roundtrip(self, blob_store: BlobStore) -> None: + data = b"hello world" + blob_store.put("stream_a", 1, data) + assert blob_store.get("stream_a", 1) == data + + def test_get_missing_raises(self, blob_store: BlobStore) -> None: + with pytest.raises(KeyError): + blob_store.get("nonexistent", 999) + + def test_put_overwrite(self, blob_store: BlobStore) -> None: + blob_store.put("s", 1, b"first") + blob_store.put("s", 1, b"second") + assert blob_store.get("s", 1) == b"second" + + def test_delete(self, blob_store: BlobStore) -> None: + blob_store.put("s", 1, b"data") + blob_store.delete("s", 1) + with pytest.raises(KeyError): + blob_store.get("s", 1) + + def test_delete_missing_raises(self, blob_store: BlobStore) -> None: + with pytest.raises(KeyError): + blob_store.delete("s", 999) + + def test_stream_isolation(self, blob_store: BlobStore) -> None: + blob_store.put("a", 1, b"alpha") + blob_store.put("b", 1, b"beta") + assert blob_store.get("a", 1) == b"alpha" + assert blob_store.get("b", 1) == b"beta" + + def test_large_blob(self, blob_store: BlobStore) -> None: + data = bytes(range(256)) * 1000 # 256 KB + blob_store.put("big", 0, data) + assert blob_store.get("big", 0) == data + assert blob_store.get("big", 0) == data diff --git a/dimos/memory2/buffer.py b/dimos/memory2/buffer.py new file mode 100644 index 0000000000..49814eb6dc --- /dev/null +++ b/dimos/memory2/buffer.py @@ -0,0 +1,248 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Backpressure buffers — the bridge between push and pull. + +Real-world data sources (cameras, LiDAR, ROS topics) and ReactiveX pipelines +are *push-based*: they emit items whenever they please. Databases, analysis +systems, and our memory store are *pull-based*: consumers iterate at their own +pace. A BackpressureBuffer sits between the two, absorbing push bursts so +that the pull side can drain items on its own schedule. + +The choice of strategy controls what happens under load: + +- **KeepLast** — single-slot, always overwrites; best for real-time sensor + data where only the latest reading matters. +- **Bounded** — FIFO with a cap; drops the oldest item on overflow. +- **DropNew** — FIFO with a cap; rejects new items on overflow. +- **Unbounded** — unlimited FIFO; guarantees delivery at the cost of memory. + +All four share the same ABC interface and are interchangeable wherever a +buffer is accepted (e.g. ``Stream.live(buffer=...)``). +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections import deque +import threading +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + from collections.abc import Iterator + +T = TypeVar("T") + + +class ClosedError(Exception): + """Raised when take() is called on a closed buffer.""" + + +class BackpressureBuffer(ABC, Generic[T]): + """Thread-safe buffer between push producers and pull consumers.""" + + @abstractmethod + def put(self, item: T) -> bool: + """Push an item. Returns False if the item was dropped.""" + + @abstractmethod + def take(self, timeout: float | None = None) -> T: + """Block until an item is available. Raises ClosedError if the buffer is closed.""" + + @abstractmethod + def try_take(self) -> T | None: + """Non-blocking take. Returns None if empty.""" + + @abstractmethod + def close(self) -> None: + """Signal no more items. Subsequent take() raises ClosedError.""" + + @abstractmethod + def __len__(self) -> int: ... + + def __iter__(self) -> Iterator[T]: + """Yield items until the buffer is closed.""" + while True: + try: + yield self.take() + except ClosedError: + return + + +class KeepLast(BackpressureBuffer[T]): + """Single-slot buffer. put() always overwrites. Default for live mode.""" + + def __init__(self) -> None: + self._item: T | None = None + self._has_item = False + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed: + return False + self._item = item + self._has_item = True + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._has_item: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + item = self._item + assert item is not None + self._item = None + self._has_item = False + return item + + def try_take(self) -> T | None: + with self._cond: + if not self._has_item: + return None + item = self._item + self._item = None + self._has_item = False + return item + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return 1 if self._has_item else 0 + + +class Bounded(BackpressureBuffer[T]): + """FIFO queue with max size. Drops oldest when full.""" + + def __init__(self, maxlen: int) -> None: + self._buf: deque[T] = deque(maxlen=maxlen) + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed: + return False + self._buf.append(item) # deque(maxlen) drops oldest automatically + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._buf: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + return self._buf.popleft() + + def try_take(self) -> T | None: + with self._cond: + return self._buf.popleft() if self._buf else None + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return len(self._buf) + + +class DropNew(BackpressureBuffer[T]): + """FIFO queue. Rejects new items when full (put returns False).""" + + def __init__(self, maxlen: int) -> None: + self._buf: deque[T] = deque() + self._maxlen = maxlen + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed or len(self._buf) >= self._maxlen: + return False + self._buf.append(item) + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._buf: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + return self._buf.popleft() + + def try_take(self) -> T | None: + with self._cond: + return self._buf.popleft() if self._buf else None + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return len(self._buf) + + +class Unbounded(BackpressureBuffer[T]): + """Unbounded FIFO queue. Use carefully — can grow without limit.""" + + def __init__(self) -> None: + self._buf: deque[T] = deque() + self._closed = False + self._cond = threading.Condition() + + def put(self, item: T) -> bool: + with self._cond: + if self._closed: + return False + self._buf.append(item) + self._cond.notify() + return True + + def take(self, timeout: float | None = None) -> T: + with self._cond: + while not self._buf: + if self._closed: + raise ClosedError("Buffer is closed") + if not self._cond.wait(timeout): + raise TimeoutError("take() timed out") + return self._buf.popleft() + + def try_take(self) -> T | None: + with self._cond: + return self._buf.popleft() if self._buf else None + + def close(self) -> None: + with self._cond: + self._closed = True + self._cond.notify_all() + + def __len__(self) -> int: + with self._cond: + return len(self._buf) diff --git a/dimos/memory2/codecs/README.md b/dimos/memory2/codecs/README.md new file mode 100644 index 0000000000..8ad40e95fd --- /dev/null +++ b/dimos/memory2/codecs/README.md @@ -0,0 +1,57 @@ +# codecs + +Encode/decode payloads for persistent storage. Codecs convert typed Python objects to `bytes` and back, used by backends that store observation data as blobs. + +## Protocol + +```python +class Codec(Protocol[T]): + def encode(self, value: T) -> bytes: ... + def decode(self, data: bytes) -> T: ... +``` + +## Built-in codecs + +| Codec | Type | Notes | +|-------|------|-------| +| `PickleCodec` | Any Python object | Fallback. Uses `HIGHEST_PROTOCOL`. | +| `JpegCodec` | `Image` | Lossy compression via TurboJPEG. ~10-20x smaller. Preserves `frame_id` in header. | +| `LcmCodec` | `DimosMsg` subclasses | Uses `lcm_encode()`/`lcm_decode()`. Zero-copy for LCM message types. | + +## Auto-selection + +`codec_for(payload_type)` picks the right codec: + +```python +from dimos.memory2.codecs import codec_for + +codec_for(Image) # → JpegCodec(quality=50) +codec_for(SomeLcmMsg) # → LcmCodec(SomeLcmMsg) (if has lcm_encode/lcm_decode) +codec_for(dict) # → PickleCodec() (fallback) +codec_for(None) # → PickleCodec() +``` + +## Writing a new codec + +1. Create `dimos/memory/codecs/mycodec.py`: + +```python +class MyCodec: + def encode(self, value: MyType) -> bytes: + ... + + def decode(self, data: bytes) -> MyType: + ... +``` + +2. Add a branch in `codec_for()` in `base.py` to auto-select it for the relevant type. + +3. Add a test case to `test_codecs.py` — the grid fixture makes this easy: + +```python +@pytest.fixture(params=[..., ("mycodec", MyCodec(), sample_value)]) +def codec_case(request): + ... +``` + +No base class needed — `Codec` is a protocol. Just implement `encode` and `decode`. diff --git a/dimos/memory2/codecs/base.py b/dimos/memory2/codecs/base.py new file mode 100644 index 0000000000..821b36b60f --- /dev/null +++ b/dimos/memory2/codecs/base.py @@ -0,0 +1,112 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +from typing import Any, Protocol, TypeVar, runtime_checkable + +T = TypeVar("T") + + +@runtime_checkable +class Codec(Protocol[T]): + """Encode/decode payloads for storage.""" + + def encode(self, value: T) -> bytes: ... + def decode(self, data: bytes) -> T: ... + + +def codec_for(payload_type: type[Any] | None = None) -> Codec[Any]: + """Auto-select codec based on payload type.""" + from dimos.memory2.codecs.pickle import PickleCodec + + if payload_type is not None: + from dimos.msgs.sensor_msgs.Image import Image + + if issubclass(payload_type, Image): + from dimos.memory2.codecs.jpeg import JpegCodec + + return JpegCodec() + if hasattr(payload_type, "lcm_encode") and hasattr(payload_type, "lcm_decode"): + from dimos.memory2.codecs.lcm import LcmCodec + + return LcmCodec(payload_type) + return PickleCodec() + + +def codec_id(codec: Codec[Any]) -> str: + """Derive a string ID from a codec instance, e.g. ``'lz4+lcm'``. + + Walks the ``_inner`` chain for wrapper codecs, joining with ``+``. + Uses the naming convention ``FooCodec`` → ``'foo'``. + """ + parts: list[str] = [] + c: Any = codec + while hasattr(c, "_inner"): + parts.append(_class_to_id(c)) + c = c._inner + parts.append(_class_to_id(c)) + return "+".join(parts) + + +def codec_from_id(codec_id_str: str, payload_module: str) -> Codec[Any]: + """Reconstruct a codec chain from its string ID (e.g. ``'lz4+lcm'``). + + Builds inside-out: the rightmost segment is the innermost (base) codec. + """ + parts = codec_id_str.split("+") + # Innermost first + result = _make_one(parts[-1], payload_module) + for name in reversed(parts[:-1]): + result = _make_one(name, payload_module, inner=result) + return result + + +def _class_to_id(codec: Any) -> str: + name = type(codec).__name__ + if name.endswith("Codec"): + return name[:-5].lower() + return name.lower() + + +def _resolve_payload_type(payload_module: str) -> type[Any]: + parts = payload_module.rsplit(".", 1) + if len(parts) != 2: + raise ValueError(f"Cannot resolve payload type from {payload_module!r}") + mod = importlib.import_module(parts[0]) + return getattr(mod, parts[1]) # type: ignore[no-any-return] + + +def _make_one(name: str, payload_module: str, inner: Codec[Any] | None = None) -> Codec[Any]: + """Instantiate a single codec by its short name.""" + if name == "lz4": + from dimos.memory2.codecs.lz4 import Lz4Codec + + if inner is None: + raise ValueError("lz4 is a wrapper codec — must have an inner codec") + return Lz4Codec(inner) + if name == "jpeg": + from dimos.memory2.codecs.jpeg import JpegCodec + + return JpegCodec() + if name == "lcm": + from dimos.memory2.codecs.lcm import LcmCodec + + return LcmCodec(_resolve_payload_type(payload_module)) + if name == "pickle": + from dimos.memory2.codecs.pickle import PickleCodec + + return PickleCodec() + raise ValueError(f"Unknown codec: {name!r}") diff --git a/dimos/memory2/codecs/jpeg.py b/dimos/memory2/codecs/jpeg.py new file mode 100644 index 0000000000..3d854400b1 --- /dev/null +++ b/dimos/memory2/codecs/jpeg.py @@ -0,0 +1,39 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dimos.msgs.sensor_msgs.Image import Image + + +class JpegCodec: + """Codec for Image types — JPEG-compressed inside an LCM Image envelope. + + Uses ``Image.lcm_jpeg_encode/decode`` which preserves ``ts``, ``frame_id``, + and all LCM header fields. Pixel data is lossy-compressed via TurboJPEG. + """ + + def __init__(self, quality: int = 50) -> None: + self._quality = quality + + def encode(self, value: Image) -> bytes: + return value.lcm_jpeg_encode(quality=self._quality) + + def decode(self, data: bytes) -> Image: + from dimos.msgs.sensor_msgs.Image import Image + + return Image.lcm_jpeg_decode(data) diff --git a/dimos/memory2/codecs/lcm.py b/dimos/memory2/codecs/lcm.py new file mode 100644 index 0000000000..fe7055d9c8 --- /dev/null +++ b/dimos/memory2/codecs/lcm.py @@ -0,0 +1,33 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dimos.msgs.protocol import DimosMsg + + +class LcmCodec: + """Codec for DimosMsg types — uses lcm_encode/lcm_decode.""" + + def __init__(self, msg_type: type[DimosMsg]) -> None: + self._msg_type = msg_type + + def encode(self, value: DimosMsg) -> bytes: + return value.lcm_encode() + + def decode(self, data: bytes) -> DimosMsg: + return self._msg_type.lcm_decode(data) diff --git a/dimos/memory2/codecs/lz4.py b/dimos/memory2/codecs/lz4.py new file mode 100644 index 0000000000..15cbad56e4 --- /dev/null +++ b/dimos/memory2/codecs/lz4.py @@ -0,0 +1,42 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import lz4.frame # type: ignore[import-untyped] + +if TYPE_CHECKING: + from dimos.memory2.codecs.base import Codec + + +class Lz4Codec: + """Wraps another codec and applies LZ4 frame compression to the output. + + Works with any inner codec — compresses the bytes produced by + ``inner.encode()`` and decompresses before ``inner.decode()``. + """ + + def __init__(self, inner: Codec[Any], compression_level: int = 0) -> None: + self._inner = inner + self._compression_level = compression_level + + def encode(self, value: Any) -> bytes: + raw = self._inner.encode(value) + return bytes(lz4.frame.compress(raw, compression_level=self._compression_level)) + + def decode(self, data: bytes) -> Any: + raw: bytes = lz4.frame.decompress(data) + return self._inner.decode(raw) diff --git a/dimos/memory2/codecs/pickle.py b/dimos/memory2/codecs/pickle.py new file mode 100644 index 0000000000..7200e1da50 --- /dev/null +++ b/dimos/memory2/codecs/pickle.py @@ -0,0 +1,28 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pickle +from typing import Any + + +class PickleCodec: + """Fallback codec for arbitrary Python objects.""" + + def encode(self, value: Any) -> bytes: + return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) + + def decode(self, data: bytes) -> Any: + return pickle.loads(data) diff --git a/dimos/memory2/codecs/test_codecs.py b/dimos/memory2/codecs/test_codecs.py new file mode 100644 index 0000000000..eece78b1c3 --- /dev/null +++ b/dimos/memory2/codecs/test_codecs.py @@ -0,0 +1,185 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grid tests for Codec implementations. + +Runs roundtrip encode→decode tests across every codec, verifying data preservation. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import pytest + +from dimos.memory2.codecs.base import Codec, codec_for +from dimos.memory2.codecs.jpeg import JpegCodec +from dimos.memory2.codecs.lcm import LcmCodec +from dimos.memory2.codecs.pickle import PickleCodec +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image +from dimos.utils.testing.replay import TimedSensorReplay + +if TYPE_CHECKING: + from collections.abc import Callable + + from dimos.msgs.protocol import DimosMsg + + +@dataclass +class Case: + name: str + codec: Codec[Any] + values: list[Any] + eq: Callable[[Any, Any], bool] | None = None # custom equality: (original, decoded) -> bool + + +def _lcm_values() -> list[DimosMsg]: + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + return [ + PoseStamped( + ts=1.0, + frame_id="map", + position=Vector3(1.0, 2.0, 3.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ), + PoseStamped(ts=0.5, frame_id="odom"), + ] + + +def _pickle_case() -> Case: + from dimos.memory2.codecs.pickle import PickleCodec + + return Case( + name="pickle", + codec=PickleCodec(), + values=[42, "hello", b"raw bytes", {"key": "value"}], + ) + + +def _lcm_case() -> Case: + from dimos.memory2.codecs.lcm import LcmCodec + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + return Case( + name="lcm", + codec=LcmCodec(PoseStamped), + values=_lcm_values(), + ) + + +def _lz4_pickle_case() -> Case: + from dimos.memory2.codecs.lz4 import Lz4Codec + from dimos.memory2.codecs.pickle import PickleCodec + + return Case( + name="lz4+pickle", + codec=Lz4Codec(PickleCodec()), + values=[42, "hello", b"raw bytes", {"key": "value"}, list(range(1000))], + ) + + +def _lz4_lcm_case() -> Case: + from dimos.memory2.codecs.lcm import LcmCodec + from dimos.memory2.codecs.lz4 import Lz4Codec + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + return Case( + name="lz4+lcm", + codec=Lz4Codec(LcmCodec(PoseStamped)), + values=_lcm_values(), + ) + + +def _jpeg_eq(original: Any, decoded: Any) -> bool: + """JPEG is lossy — check shape, frame_id, and pixel closeness.""" + import numpy as np + + if decoded.data.shape != original.data.shape: + return False + if decoded.frame_id != original.frame_id: + return False + return bool(np.mean(np.abs(decoded.data.astype(float) - original.data.astype(float))) < 5) + + +def _jpeg_case() -> Case | None: + try: + from turbojpeg import TurboJPEG + + TurboJPEG() # fail fast if native lib is missing + + replay = TimedSensorReplay("unitree_go2_bigoffice/video") + frames = [replay.find_closest_seek(float(i)) for i in range(1, 4)] + codec = JpegCodec(quality=95) + except (ImportError, RuntimeError): + return None + + return Case( + name="jpeg", + codec=codec, + values=frames, + eq=_jpeg_eq, + ) + + +testcases = [ + c + for c in [_pickle_case(), _lcm_case(), _lz4_pickle_case(), _lz4_lcm_case(), _jpeg_case()] + if c is not None +] + + +@pytest.mark.parametrize("case", testcases, ids=lambda c: c.name) +class TestCodecRoundtrip: + """Every codec must perfectly roundtrip its values.""" + + def test_roundtrip_preserves_value(self, case: Case) -> None: + eq = case.eq or (lambda a, b: a == b) + for value in case.values: + encoded = case.codec.encode(value) + assert isinstance(encoded, bytes) + decoded = case.codec.decode(encoded) + assert eq(value, decoded), f"Roundtrip failed for {value!r}: got {decoded!r}" + + def test_encode_returns_nonempty_bytes(self, case: Case) -> None: + for value in case.values: + encoded = case.codec.encode(value) + assert len(encoded) > 0, f"Empty encoding for {value!r}" + + def test_different_values_produce_different_bytes(self, case: Case) -> None: + encodings = [case.codec.encode(v) for v in case.values] + assert len(set(encodings)) > 1, "All values encoded to identical bytes" + + +class TestCodecFor: + """codec_for() auto-selects the right codec.""" + + def test_none_returns_pickle(self) -> None: + assert isinstance(codec_for(None), PickleCodec) + + def test_unknown_type_returns_pickle(self) -> None: + assert isinstance(codec_for(dict), PickleCodec) + + def test_lcm_type_returns_lcm(self) -> None: + assert isinstance(codec_for(PoseStamped), LcmCodec) + + def test_image_type_returns_jpeg(self) -> None: + pytest.importorskip("turbojpeg") + from dimos.memory2.codecs.jpeg import JpegCodec + + assert isinstance(codec_for(Image), JpegCodec) diff --git a/dimos/memory2/conftest.py b/dimos/memory2/conftest.py new file mode 100644 index 0000000000..68cea71c2d --- /dev/null +++ b/dimos/memory2/conftest.py @@ -0,0 +1,89 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared fixtures for memory2 tests.""" + +from __future__ import annotations + +import sqlite3 +import tempfile +from typing import TYPE_CHECKING + +import pytest + +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore +from dimos.memory2.store.memory import MemoryStore +from dimos.memory2.store.sqlite import SqliteStore + +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path + + from dimos.memory2.blobstore.base import BlobStore + from dimos.memory2.store.base import Store + + +@pytest.fixture +def memory_store() -> Iterator[MemoryStore]: + with MemoryStore() as store: + yield store + + +@pytest.fixture +def memory_session(memory_store: MemoryStore) -> Iterator[MemoryStore]: + """Alias: in the new architecture, the store IS the session.""" + yield memory_store + + +@pytest.fixture +def sqlite_store() -> Iterator[SqliteStore]: + with tempfile.NamedTemporaryFile(suffix=".db") as f: + store = SqliteStore(path=f.name) + with store: + yield store + + +@pytest.fixture +def sqlite_session(sqlite_store: SqliteStore) -> Iterator[SqliteStore]: + """Alias: in the new architecture, the store IS the session.""" + yield sqlite_store + + +@pytest.fixture(params=["memory_store", "sqlite_store"]) +def session(request: pytest.FixtureRequest) -> Store: + """Parametrized fixture that runs tests against both backends. + + Named 'session' to minimize test changes — tests use session.stream() which + now goes directly to Store.stream(). + """ + return request.getfixturevalue(request.param) + + +@pytest.fixture +def file_blob_store(tmp_path: Path) -> Iterator[FileBlobStore]: + with FileBlobStore(root=str(tmp_path / "blobs")) as store: + yield store + + +@pytest.fixture +def sqlite_blob_store() -> Iterator[SqliteBlobStore]: + conn = sqlite3.connect(":memory:") + with SqliteBlobStore(conn=conn) as store: + yield store + + +@pytest.fixture(params=["file_blob_store", "sqlite_blob_store"]) +def blob_store(request: pytest.FixtureRequest) -> BlobStore: + return request.getfixturevalue(request.param) diff --git a/dimos/memory2/embed.py b/dimos/memory2/embed.py new file mode 100644 index 0000000000..17b5b98a31 --- /dev/null +++ b/dimos/memory2/embed.py @@ -0,0 +1,79 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from itertools import islice +from typing import TYPE_CHECKING, Any, TypeVar + +from dimos.memory2.transform import Transformer + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.memory2.type.observation import Observation + from dimos.models.embedding.base import EmbeddingModel + +T = TypeVar("T") + + +def _batched(it: Iterator[T], n: int) -> Iterator[list[T]]: + """Yield successive n-sized chunks from an iterator.""" + while True: + batch = list(islice(it, n)) + if not batch: + return + yield batch + + +class EmbedImages(Transformer[Any, Any]): + """Embed images using ``model.embed()``. + + Data type stays the same — observations are enriched with an + ``.embedding`` field, yielding :class:`EmbeddedObservation` instances. + """ + + def __init__(self, model: EmbeddingModel, batch_size: int = 32) -> None: + self.model = model + self.batch_size = batch_size + + def __call__(self, upstream: Iterator[Observation[Any]]) -> Iterator[Observation[Any]]: + for batch in _batched(upstream, self.batch_size): + images = [obs.data for obs in batch] + embeddings = self.model.embed(*images) + if not isinstance(embeddings, list): + embeddings = [embeddings] + for obs, emb in zip(batch, embeddings, strict=False): + yield obs.derive(data=obs.data, embedding=emb) + + +class EmbedText(Transformer[Any, Any]): + """Embed text using ``model.embed_text()``. + + Data type stays the same — observations are enriched with an + ``.embedding`` field, yielding :class:`EmbeddedObservation` instances. + """ + + def __init__(self, model: EmbeddingModel, batch_size: int = 32) -> None: + self.model = model + self.batch_size = batch_size + + def __call__(self, upstream: Iterator[Observation[Any]]) -> Iterator[Observation[Any]]: + for batch in _batched(upstream, self.batch_size): + texts = [str(obs.data) for obs in batch] + embeddings = self.model.embed_text(*texts) + if not isinstance(embeddings, list): + embeddings = [embeddings] + for obs, emb in zip(batch, embeddings, strict=False): + yield obs.derive(data=obs.data, embedding=emb) diff --git a/dimos/memory2/embeddings.md b/dimos/memory2/embeddings.md new file mode 100644 index 0000000000..9028c29f9d --- /dev/null +++ b/dimos/memory2/embeddings.md @@ -0,0 +1,148 @@ +# memory Embedding Design + +## Core Principle: Enrichment, Not Replacement + +The embedding annotates the observation — it doesn't replace `.data`. +In memory1, `.data` IS the embedding and you need `parent_id` + `project_to()` to get back to the source image. We avoid this entirely. + +## Observation Types + +```python +@dataclass +class Observation(Generic[T]): + id: int + ts: float + pose: Any | None = None + tags: dict[str, Any] = field(default_factory=dict) + _data: T | _Unloaded = ... + _loader: Callable[[], T] | None = None # lazy loading via blob store + +@dataclass +class EmbeddedObservation(Observation[T]): + embedding: Embedding | None = None # populated by Embed transformer + similarity: float | None = None # populated by .search() +``` + +`EmbeddedObservation` is a subclass — passes anywhere `Observation` is accepted (LSP). +Users who don't care about types just use `Observation`. Users who want precision annotate with `EmbeddedObservation`. + +`derive()` on `Observation` promotes to `EmbeddedObservation` if `embedding=` is passed. +`derive()` on `EmbeddedObservation` returns `EmbeddedObservation`, preserving the embedding unless explicitly replaced. + +## Embed Transformer + +`Embed` is `Transformer[T, T]` — same data type in and out. It populates `.embedding` on each observation: + +```python +class Embed(Transformer[T, T]): + def __init__(self, model: EmbeddingModel): + self.model = model + + def __call__(self, upstream): + for batch in batched(upstream, 32): + vecs = self.model.embed_batch([obs.data for obs in batch]) + for obs, vec in zip(batch, vecs): + yield obs.derive(data=obs.data, embedding=vec) +``` + +`Stream[Image]` stays `Stream[Image]` after embedding — `T` is about `.data`, not the observation subclass. + +## Search + +`.search(query_vec, k)` lives on `Stream` itself. Returns a new Stream filtered to top-k by cosine similarity: + +```python +query_vec = clip.embed_text("a cat in the kitchen") + +results = images.transform(Embed(clip)).search(query_vec, k=20).fetch() +# results[0].data → Image +# results[0].embedding → np.ndarray +# results[0].similarity → 0.93 + +# Chainable with other filters +results = images.transform(Embed(clip)) \ + .search(query_vec, k=50) \ + .after(one_hour_ago) \ + .near(kitchen_pose, 5.0) \ + .fetch() +``` + +## Backend Handles Storage Strategy + +The Backend composite decides how to route storage based on what it sees: + +- `append(image, ts=now, embedding=vec)` → backend routes: blob via BlobStore, vector via VectorStore, metadata via ObservationStore +- `append(image, ts=now)` → blob + metadata only (no embedding) +- `ListObservationStore`: stores metadata in-memory, brute-force cosine via MemoryVectorStore +- `SqliteObservationStore`: metadata in SQLite, vec0 side table for fast ANN search via SqliteVectorStore +- Future backends (Postgres/pgvector, Qdrant, etc.) do their thing + +Search is pushed down to the VectorStore. Stream just passes `.search()` calls through. + +## Projection / Lineage + +**Usually not needed.** Since `.data` IS the original data, search results give you the image directly. + +When a downstream transform replaces `.data` (e.g., Image → Detection), use temporal join to get back to the source: + +```python +detection = detections.first() +detection.data # → Detection +detection.ts # → timestamp preserved by derive() + +# Get the source image via temporal join +source_image = images.at(detection.ts).first() +``` + +## Multi-Modal + +**Same embedding space = same stream.** CLIP maps images and text to the same 512-d space: + +```python +unified = store.stream("clip_unified") + +for obs in images.transform(Embed(clip.vision)): + unified.append(obs.data, ts=obs.ts, + tags={"modality": "image"}, embedding=obs.embedding) + +for obs in logs.transform(Embed(clip.text)): + unified.append(obs.data, ts=obs.ts, + tags={"modality": "text"}, embedding=obs.embedding) + +results = unified.search(query_vec, k=20).fetch() +# results[i].tags["modality"] tells you what it is +``` + +**Different embedding spaces = different streams.** Can't mix CLIP and sentence-transformer vectors. + +## Chaining — Embedding as Cheap Pre-Filter + +```python +smoke_query = clip.embed_text("smoke or fire") + +detections = images.transform(Embed(clip)) \ + .search(smoke_query, k=100) \ + .transform(ExpensiveVLMDetector()) +# VLM only runs on 100 most promising frames + +# Smart transformer can use embedding directly +class SmartDetector(Transformer[Image, Detection]): + def __call__(self, upstream: Iterator[EmbeddedObservation[Image]]) -> ...: + for obs in upstream: + if obs.embedding @ self.query > 0.3: + yield obs.derive(data=self.detect(obs.data)) +``` + +## Text Search (FTS) — Separate Concern + +FTS is keyword-based, not embedding-based. Complementary, not competing: + +```python +# Keyword search via FTS5 +logs = store.stream("logs") +logs.search_text("motor fault").fetch() + +# Semantic search via embeddings +log_idx = logs.transform(Embed(sentence_model)).store("log_emb") +log_idx.search(model.embed("motor problems"), k=10).fetch() +``` diff --git a/dimos/memory2/intro.md b/dimos/memory2/intro.md new file mode 100644 index 0000000000..e88561c283 --- /dev/null +++ b/dimos/memory2/intro.md @@ -0,0 +1,170 @@ +# Memory Intro + +## Quick start + +```python session=memory ansi=false no-result +from dimos.memory2.store.sqlite import SqliteStore + +store = SqliteStore(path="/tmp/memory_readme.db") +``` + + +```python session=memory ansi=false +logs = store.stream("logs", str) +print(logs) +``` + + +``` +Stream("logs") +``` + +Append observations: + +```python session=memory ansi=false +logs.append("Motor started", ts=1.0, tags={"level": "info"}) +logs.append("Joint 3 fault", ts=2.0, tags={"level": "error"}) +logs.append("Motor stopped", ts=3.0, tags={"level": "info"}) + +print(logs.summary()) +``` + + +``` +Stream("logs"): 3 items, 1970-01-01 00:00:01 — 1970-01-01 00:00:03 (2.0s) +``` + +## Filters + +Queries are lazy — chaining filters builds a pipeline without fetching: + +```python session=memory ansi=false +print(logs.at(1.0).before(5.0).tags(level="error")) +``` + + +``` +Stream("logs") | AtFilter(t=1.0, tolerance=1.0) | BeforeFilter(t=5.0) | TagsFilter(tags={'level': 'error'}) +``` + +Available filters: `.after(t)`, `.before(t)`, `.at(t)`, `.near(pose, radius)`, `.tags(**kv)`, `.filter(predicate)`, `.search(embedding, k)`, `.order_by(field)`, `.limit(k)`, `.offset(n)`. + +## Terminals + +Terminals materialize or consume the stream: + +```python session=memory ansi=false +print(logs.before(5.0).tags(level="error").fetch()) +``` + + +``` +[Observation(id=2, ts=2.0, pose=None, tags={'level': 'error'})] +``` + +Available terminals: `.fetch()`, `.first()`, `.last()`, `.count()`, `.exists()`, `.summary()`, `.get_time_range()`, `.drain()`, `.save(target)`. + +## Transforms + +`.map(fn)` transforms each observation, returning a new stream: + +```python session=memory ansi=false +print(logs.map(lambda obs: obs.data.upper()).first()) +``` + + +``` +MOTOR STARTED +``` + +## Live queries + +Live queries backfill existing matches, then emit new ones as they arrive: + +```python session=memory ansi=false +import time + +def emit_some_logs(): + last_ts = logs.last().ts + logs.append("Heartbeat ok", ts=last_ts + 1, pose=(3.0, 1.5, 0.0), tags={"level": "info"}) + time.sleep(0.1) + logs.append("Sensor fault", ts=last_ts + 2, pose=(4.1, 2.0, 0.0), tags={"level": "error"}) + time.sleep(0.1) + logs.append("Battery low: 30%", ts=last_ts + 3, pose=(5.3, 2.5, 0.0), tags={"level": "info"}) + time.sleep(0.1) + logs.append("Overtemp", ts=last_ts + 4, pose=(6.0, 3.0, 0.0), tags={"level": "error"}) + time.sleep(0.1) + + +with logs.tags(level="error").live() as errors: + sub = errors.subscribe(lambda obs: print(f"{obs.ts} - {obs.data}")) + emit_some_logs() + sub.dispose() + +``` + + +``` +2.0 - Joint 3 fault +5.0 - Sensor fault +7.0 - Overtemp +``` + +## Spatial + live + +Filters compose freely. Here `.near()` + `.live()` + `.map()` watches for logs near a physical location — backfilling past matches and tailing new ones: + +```python session=memory ansi=false +near_query = logs.near((5.0, 2.0), radius=2.0).live() +with near_query.map(lambda obs: f"near POI - {obs.data}") as logs_near: + with logs_near.subscribe(print): + emit_some_logs() +``` + + +``` +near POI - Sensor fault +near POI - Battery low: 30% +near POI - Overtemp +near POI - Sensor fault +near POI - Battery low: 30% +near POI - Overtemp +``` + +## Embeddings + +Use `EmbedText` transformer with CLIP to enrich observations with embeddings, then search by similarity: + +`.search(embedding, k)` returns the top-k most similar observations by cosine similarity: + +```python session=memory ansi=false +from dimos.models.embedding.clip import CLIPModel +from dimos.memory2.embed import EmbedText + +clip = CLIPModel() + +for obs in logs.transform(EmbedText(clip)).search(clip.embed_text("hardware problem"), k=3).fetch(): + print(f"{obs.similarity:.3f} {obs.data}") +``` + + +``` +0.897 Sensor fault +0.897 Sensor fault +0.887 Battery low: 30% +``` + +The embedded stream above was ephemeral — built on the fly for one query. To persist embeddings automatically as logs arrive, pipe a live stream through the transform into a stored stream: + +```python skip +import threading + +embedded_logs = store.stream("embedded_logs", str) +threading.Thread( + target=lambda: logs.live().transform(EmbedText(clip)).save(embedded_logs), + daemon=True, +).start() + +# every new log is now automatically embedded and stored +# embedded_logs.search(query, k=5).fetch() to query at any time +``` diff --git a/dimos/memory2/notes.md b/dimos/memory2/notes.md new file mode 100644 index 0000000000..8a9a05c30c --- /dev/null +++ b/dimos/memory2/notes.md @@ -0,0 +1,10 @@ + +```python +with db() as db: + with db.stream as image: + image.put(...) +``` + +DB specifies some general configuration for all sessions/streams. + +`db.stream` initializes these sessions? diff --git a/dimos/memory2/notifier/base.py b/dimos/memory2/notifier/base.py new file mode 100644 index 0000000000..022d26d4e0 --- /dev/null +++ b/dimos/memory2/notifier/base.py @@ -0,0 +1,62 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.memory2.registry import qual +from dimos.protocol.service.spec import BaseConfig, Configurable + +if TYPE_CHECKING: + from reactivex.abc import DisposableBase + + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.type.observation import Observation + +T = TypeVar("T") + + +class NotifierConfig(BaseConfig): + pass + + +class Notifier(Configurable[NotifierConfig], Generic[T]): + """Push-notification for live observation delivery. + + Decouples the notification mechanism from storage. The built-in + ``SubjectNotifier`` handles same-process fan-out (thread-safe, zero + config). External implementations (Redis pub/sub, Postgres + LISTEN/NOTIFY, inotify) can be injected for cross-process use. + """ + + default_config: type[NotifierConfig] = NotifierConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + + @abstractmethod + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: + """Register *buf* to receive new observations. Returns a + disposable that unsubscribes when disposed.""" + ... + + @abstractmethod + def notify(self, obs: Observation[T]) -> None: + """Fan out *obs* to all current subscribers.""" + ... + + def serialize(self) -> dict[str, Any]: + return {"class": qual(type(self)), "config": self.config.model_dump()} diff --git a/dimos/memory2/notifier/subject.py b/dimos/memory2/notifier/subject.py new file mode 100644 index 0000000000..d1b8d7f888 --- /dev/null +++ b/dimos/memory2/notifier/subject.py @@ -0,0 +1,70 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""In-memory fan-out notifier (same-process, thread-safe).""" + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Any, TypeVar + +from reactivex.disposable import Disposable + +from dimos.memory2.notifier.base import Notifier, NotifierConfig + +if TYPE_CHECKING: + from reactivex.abc import DisposableBase + + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.type.observation import Observation + +T = TypeVar("T") + + +class SubjectNotifierConfig(NotifierConfig): + pass + + +class SubjectNotifier(Notifier[T]): + """In-memory fan-out notifier for same-process live notification. + + Thread-safe. ``notify()`` copies the subscriber list under the lock, + then iterates outside the lock to avoid deadlocks with slow consumers. + """ + + default_config = SubjectNotifierConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._subscribers: list[BackpressureBuffer[Observation[T]]] = [] + self._lock = threading.Lock() + + def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: + with self._lock: + self._subscribers.append(buf) + + def _unsubscribe() -> None: + with self._lock: + try: + self._subscribers.remove(buf) + except ValueError: + pass + + return Disposable(action=_unsubscribe) + + def notify(self, obs: Observation[T]) -> None: + with self._lock: + subs = list(self._subscribers) + for buf in subs: + buf.put(obs) diff --git a/dimos/memory2/observationstore/base.py b/dimos/memory2/observationstore/base.py new file mode 100644 index 0000000000..4d94889fb0 --- /dev/null +++ b/dimos/memory2/observationstore/base.py @@ -0,0 +1,73 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.core.resource import CompositeResource +from dimos.memory2.registry import qual +from dimos.protocol.service.spec import BaseConfig, Configurable + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.memory2.type.filter import StreamQuery + from dimos.memory2.type.observation import Observation + +T = TypeVar("T") + + +class ObservationStoreConfig(BaseConfig): + pass + + +class ObservationStore(Configurable[ObservationStoreConfig], CompositeResource, Generic[T]): + """Core metadata storage and query engine for observations. + + Handles only observation metadata storage, query pushdown, and count. + Blob/vector/live orchestration is handled by the concrete Backend class. + """ + + default_config: type[ObservationStoreConfig] = ObservationStoreConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) + + @property + @abstractmethod + def name(self) -> str: ... + + @abstractmethod + def insert(self, obs: Observation[T]) -> int: + """Insert observation metadata, return assigned id.""" + ... + + @abstractmethod + def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + """Execute query against metadata. Blobs are NOT loaded here.""" + ... + + @abstractmethod + def count(self, q: StreamQuery) -> int: ... + + @abstractmethod + def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: + """Batch fetch by id (for vector search results).""" + ... + + def serialize(self) -> dict[str, Any]: + return {"class": qual(type(self)), "config": self.config.model_dump()} diff --git a/dimos/memory2/observationstore/memory.py b/dimos/memory2/observationstore/memory.py new file mode 100644 index 0000000000..529cd06394 --- /dev/null +++ b/dimos/memory2/observationstore/memory.py @@ -0,0 +1,80 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Any, TypeVar + +from dimos.memory2.observationstore.base import ObservationStore, ObservationStoreConfig + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.memory2.type.filter import StreamQuery + from dimos.memory2.type.observation import Observation + +T = TypeVar("T") + + +class ListObservationStoreConfig(ObservationStoreConfig): + name: str = "" + + +class ListObservationStore(ObservationStore[T]): + """In-memory metadata store for experimentation. Thread-safe.""" + + default_config = ListObservationStoreConfig + config: ListObservationStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._name = self.config.name + self._observations: list[Observation[T]] = [] + self._next_id = 0 + self._lock = threading.Lock() + + @property + def name(self) -> str: + return self._name + + def insert(self, obs: Observation[T]) -> int: + with self._lock: + obs.id = self._next_id + row_id = self._next_id + self._next_id += 1 + self._observations.append(obs) + return row_id + + def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + with self._lock: + snapshot = list(self._observations) + + # Text search — substring match + if q.search_text is not None: + needle = q.search_text.lower() + it: Iterator[Observation[T]] = ( + obs for obs in snapshot if needle in str(obs.data).lower() + ) + return q.apply(it) + + return q.apply(iter(snapshot)) + + def count(self, q: StreamQuery) -> int: + return sum(1 for _ in self.query(q)) + + def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: + id_set = set(ids) + with self._lock: + return [obs for obs in self._observations if obs.id in id_set] diff --git a/dimos/memory2/observationstore/sqlite.py b/dimos/memory2/observationstore/sqlite.py new file mode 100644 index 0000000000..5d680c540a --- /dev/null +++ b/dimos/memory2/observationstore/sqlite.py @@ -0,0 +1,444 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import re +import sqlite3 +import threading +from typing import TYPE_CHECKING, Any, TypeVar + +from pydantic import Field, model_validator + +from dimos.memory2.codecs.base import Codec +from dimos.memory2.observationstore.base import ObservationStore, ObservationStoreConfig +from dimos.memory2.type.filter import ( + AfterFilter, + AtFilter, + BeforeFilter, + NearFilter, + TagsFilter, + TimeRangeFilter, + _xyz, +) +from dimos.memory2.type.observation import _UNLOADED, Observation +from dimos.memory2.utils.sqlite import open_disposable_sqlite_connection + +if TYPE_CHECKING: + from collections.abc import Iterator + + from dimos.memory2.type.filter import Filter, StreamQuery + +T = TypeVar("T") + +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def _decompose_pose(pose: Any) -> tuple[float, ...] | None: + if pose is None: + return None + if hasattr(pose, "position"): + pos = pose.position + orient = getattr(pose, "orientation", None) + x, y, z = float(pos.x), float(pos.y), float(getattr(pos, "z", 0.0)) + if orient is not None: + return (x, y, z, float(orient.x), float(orient.y), float(orient.z), float(orient.w)) + return (x, y, z, 0.0, 0.0, 0.0, 1.0) + if isinstance(pose, (list, tuple)): + vals = [float(v) for v in pose] + while len(vals) < 7: + vals.append(0.0 if len(vals) < 6 else 1.0) + return tuple(vals[:7]) + return None + + +def _reconstruct_pose( + x: float | None, + y: float | None, + z: float | None, + qx: float | None, + qy: float | None, + qz: float | None, + qw: float | None, +) -> tuple[float, ...] | None: + if x is None: + return None + return (x, y or 0.0, z or 0.0, qx or 0.0, qy or 0.0, qz or 0.0, qw or 1.0) + + +def _compile_filter(f: Filter, stream: str, prefix: str = "") -> tuple[str, list[Any]] | None: + """Compile a filter to SQL WHERE clause. Returns None for non-pushable filters. + + ``stream`` is the raw stream name (for R*Tree table references). + ``prefix`` is a column qualifier (e.g. ``"meta."`` for JOIN queries). + """ + if isinstance(f, AfterFilter): + return (f"{prefix}ts > ?", [f.t]) + if isinstance(f, BeforeFilter): + return (f"{prefix}ts < ?", [f.t]) + if isinstance(f, TimeRangeFilter): + return (f"{prefix}ts >= ? AND {prefix}ts <= ?", [f.t1, f.t2]) + if isinstance(f, AtFilter): + return (f"ABS({prefix}ts - ?) <= ?", [f.t, f.tolerance]) + if isinstance(f, TagsFilter): + clauses = [] + params: list[Any] = [] + for k, v in f.tags.items(): + if not _IDENT_RE.match(k): + raise ValueError(f"Invalid tag key: {k!r}") + clauses.append(f"json_extract({prefix}tags, '$.{k}') = ?") + params.append(v) + return (" AND ".join(clauses), params) + if isinstance(f, NearFilter): + pose = f.pose + if pose is None: + return None + if hasattr(pose, "position"): + pose = pose.position + cx, cy, cz = _xyz(pose) + r = f.radius + # R*Tree bounding-box pre-filter + exact squared-distance check + rtree_sql = ( + f'{prefix}id IN (SELECT id FROM "{stream}_rtree" ' + f"WHERE x_min >= ? AND x_max <= ? " + f"AND y_min >= ? AND y_max <= ? " + f"AND z_min >= ? AND z_max <= ?)" + ) + dist_sql = ( + f"(({prefix}pose_x - ?) * ({prefix}pose_x - ?) + " + f"({prefix}pose_y - ?) * ({prefix}pose_y - ?) + " + f"({prefix}pose_z - ?) * ({prefix}pose_z - ?) <= ?)" + ) + return ( + f"{rtree_sql} AND {dist_sql}", + [ + cx - r, + cx + r, + cy - r, + cy + r, + cz - r, + cz + r, # R*Tree bbox + cx, + cx, + cy, + cy, + cz, + cz, + r * r, # squared distance + ], + ) + # PredicateFilter — not pushable + return None + + +def _compile_query( + query: StreamQuery, + table: str, + *, + join_blob: bool = False, +) -> tuple[str, list[Any], list[Filter]]: + """Compile a StreamQuery to SQL. + + Returns (sql, params, python_filters) where python_filters must be + applied as post-filters in Python. + """ + prefix = "meta." if join_blob else "" + if join_blob: + select = f'SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data FROM "{table}" AS meta JOIN "{table}_blob" AS blob ON blob.id = meta.id' + else: + select = f'SELECT id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, json(tags) FROM "{table}"' + + where_parts: list[str] = [] + params: list[Any] = [] + python_filters: list[Filter] = [] + + for f in query.filters: + compiled = _compile_filter(f, table, prefix) + if compiled is not None: + sql_part, sql_params = compiled + where_parts.append(sql_part) + params.extend(sql_params) + else: + python_filters.append(f) + + sql = select + if where_parts: + sql += " WHERE " + " AND ".join(where_parts) + + # ORDER BY + if query.order_field: + if not _IDENT_RE.match(query.order_field): + raise ValueError(f"Invalid order_field: {query.order_field!r}") + direction = "DESC" if query.order_desc else "ASC" + sql += f" ORDER BY {prefix}{query.order_field} {direction}" + else: + sql += f" ORDER BY {prefix}id ASC" + + # Only push LIMIT/OFFSET to SQL when there are no Python post-filters + if not python_filters: + if query.limit_val is not None: + if query.offset_val: + sql += f" LIMIT {query.limit_val} OFFSET {query.offset_val}" + else: + sql += f" LIMIT {query.limit_val}" + elif query.offset_val: + sql += f" LIMIT -1 OFFSET {query.offset_val}" + + return (sql, params, python_filters) + + +def _compile_count( + query: StreamQuery, + table: str, +) -> tuple[str, list[Any], list[Filter]]: + """Compile a StreamQuery to a COUNT SQL query.""" + where_parts: list[str] = [] + params: list[Any] = [] + python_filters: list[Filter] = [] + + for f in query.filters: + compiled = _compile_filter(f, table) + if compiled is not None: + sql_part, sql_params = compiled + where_parts.append(sql_part) + params.extend(sql_params) + else: + python_filters.append(f) + + sql = f'SELECT COUNT(*) FROM "{table}"' + if where_parts: + sql += " WHERE " + " AND ".join(where_parts) + + return (sql, params, python_filters) + + +class SqliteObservationStoreConfig(ObservationStoreConfig): + conn: sqlite3.Connection | None = Field(default=None, exclude=True) + name: str = "" + codec: Codec[Any] | None = Field(default=None, exclude=True) + blob_store_conn_match: bool = Field(default=False, exclude=True) + page_size: int = 256 + path: str | None = None + + @model_validator(mode="after") + def _conn_xor_path(self) -> SqliteObservationStoreConfig: + if self.conn is not None and self.path is not None: + raise ValueError("Specify either conn or path, not both") + if self.conn is None and self.path is None: + raise ValueError("Specify either conn or path") + return self + + +class SqliteObservationStore(ObservationStore[T]): + """SQLite-backed metadata store for a single stream (table). + + Handles only metadata storage and query pushdown. + Blob/vector/live orchestration is handled by Backend. + + Supports two construction modes: + + - ``SqliteObservationStore(conn=conn, name="x", codec=...)`` — borrows an externally-managed connection. + - ``SqliteObservationStore(path="file.db", name="x", codec=...)`` — opens and owns its own connection. + """ + + default_config = SqliteObservationStoreConfig + config: SqliteObservationStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._conn: sqlite3.Connection = self.config.conn # type: ignore[assignment] # set in start() if None + self._path = self.config.path + self._name = self.config.name + self._codec = self.config.codec + self._blob_store_conn_match = self.config.blob_store_conn_match + self._page_size = self.config.page_size + self._lock = threading.Lock() + self._tag_indexes: set[str] = set() + self._pending_python_filters: list[Any] = [] + self._pending_query: StreamQuery | None = None + + def start(self) -> None: + if self._conn is None: + assert self._path is not None + disposable, self._conn = open_disposable_sqlite_connection(self._path) + self.register_disposables(disposable) + self._ensure_tables() + + def _ensure_tables(self) -> None: + """Create the metadata table and R*Tree index if they don't exist.""" + self._conn.execute( + f'CREATE TABLE IF NOT EXISTS "{self._name}" (' + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " ts REAL NOT NULL UNIQUE," + " pose_x REAL, pose_y REAL, pose_z REAL," + " pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL," + " tags BLOB DEFAULT (jsonb('{}'))" + ")" + ) + self._conn.execute( + f'CREATE VIRTUAL TABLE IF NOT EXISTS "{self._name}_rtree" USING rtree(' + " id," + " x_min, x_max," + " y_min, y_max," + " z_min, z_max" + ")" + ) + self._conn.commit() + + @property + def name(self) -> str: + return self._name + + @property + def _join_blobs(self) -> bool: + return self._blob_store_conn_match + + def _make_loader(self, row_id: int, blob_store: Any) -> Any: + name = self._name + codec = self._codec + assert codec is not None, "codec is required for data loading" + + def loader() -> Any: + raw = blob_store.get(name, row_id) + return codec.decode(raw) + + return loader + + def _row_to_obs(self, row: tuple[Any, ...], *, has_blob: bool = False) -> Observation[T]: + if has_blob: + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, blob_data = row + else: + row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row + blob_data = None + + pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) + tags = json.loads(tags_json) if tags_json else {} + + if has_blob and blob_data is not None: + assert self._codec is not None, "codec is required for data loading" + data = self._codec.decode(blob_data) + return Observation(id=row_id, ts=ts, pose=pose, tags=tags, _data=data) + + return Observation( + id=row_id, + ts=ts, + pose=pose, + tags=tags, + _data=_UNLOADED, + ) + + def _ensure_tag_indexes(self, tags: dict[str, Any]) -> None: + for key in tags: + if key not in self._tag_indexes and _IDENT_RE.match(key): + self._conn.execute( + f'CREATE INDEX IF NOT EXISTS "{self._name}_tag_{key}" ' + f"ON \"{self._name}\"(json_extract(tags, '$.{key}'))" + ) + self._tag_indexes.add(key) + + def insert(self, obs: Observation[T]) -> int: + pose = _decompose_pose(obs.pose) + tags_json = json.dumps(obs.tags) if obs.tags else "{}" + + with self._lock: + if obs.tags: + self._ensure_tag_indexes(obs.tags) + if pose: + px, py, pz, qx, qy, qz, qw = pose + else: + px = py = pz = qx = qy = qz = qw = None # type: ignore[assignment] + + cur = self._conn.execute( + f'INSERT INTO "{self._name}" (ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) ' + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))", + (obs.ts, px, py, pz, qx, qy, qz, qw, tags_json), + ) + row_id = cur.lastrowid + assert row_id is not None + + # R*Tree spatial index + if pose: + self._conn.execute( + f'INSERT INTO "{self._name}_rtree" (id, x_min, x_max, y_min, y_max, z_min, z_max) ' + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (row_id, px, px, py, py, pz, pz), + ) + + # Do NOT commit here — Backend calls commit() after blob/vector writes + + return row_id + + def commit(self) -> None: + self._conn.commit() + + def rollback(self) -> None: + self._conn.rollback() + + def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + if q.search_text is not None: + raise NotImplementedError("search_text is not supported by SqliteObservationStore") + + join = self._join_blobs + sql, params, python_filters = _compile_query(q, self._name, join_blob=join) + + cur = self._conn.execute(sql, params) + cur.arraysize = self._page_size + it: Iterator[Observation[T]] = (self._row_to_obs(r, has_blob=join) for r in cur) + + # Don't apply python post-filters here — Backend._attach_loaders must + # run first so that obs.data works for PredicateFilter etc. + # Store them so Backend can retrieve and apply after attaching loaders. + self._pending_python_filters = python_filters + self._pending_query = q + + return it + + def count(self, q: StreamQuery) -> int: + if q.search_vec: + # Delegate to Backend for vector-aware counting + raise NotImplementedError("count with search_vec must go through Backend") + + sql, params, python_filters = _compile_count(q, self._name) + if python_filters: + return sum(1 for _ in self.query(q)) + + row = self._conn.execute(sql, params).fetchone() + return int(row[0]) if row else 0 + + def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: + if not ids: + return [] + join = self._join_blobs + placeholders = ",".join("?" * len(ids)) + if join: + sql = ( + f"SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, " + f"meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data " + f'FROM "{self._name}" AS meta ' + f'JOIN "{self._name}_blob" AS blob ON blob.id = meta.id ' + f"WHERE meta.id IN ({placeholders})" + ) + else: + sql = ( + f"SELECT id, ts, pose_x, pose_y, pose_z, " + f"pose_qx, pose_qy, pose_qz, pose_qw, json(tags) " + f'FROM "{self._name}" WHERE id IN ({placeholders})' + ) + + rows = self._conn.execute(sql, ids).fetchall() + return [self._row_to_obs(r, has_blob=join) for r in rows] + + def stop(self) -> None: + super().stop() diff --git a/dimos/memory2/registry.py b/dimos/memory2/registry.py new file mode 100644 index 0000000000..4e4c28da86 --- /dev/null +++ b/dimos/memory2/registry.py @@ -0,0 +1,81 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stream registry: persists fully-resolved backend config per stream.""" + +from __future__ import annotations + +import importlib +import json +import sqlite3 +from typing import Any + +from pydantic import Field + +from dimos.protocol.service.spec import BaseConfig, Configurable + + +def qual(cls: type) -> str: + """Fully qualified class name, e.g. 'dimos.memory2.blobstore.sqlite.SqliteBlobStore'.""" + return f"{cls.__module__}.{cls.__qualname__}" + + +def deserialize_component(data: dict[str, Any]) -> Any: + """Instantiate a component from its ``{"class": ..., "config": ...}`` dict.""" + module_path, _, cls_name = data["class"].rpartition(".") + mod = importlib.import_module(module_path) + cls = getattr(mod, cls_name) + return cls(**data["config"]) + + +class RegistryStoreConfig(BaseConfig): + conn: sqlite3.Connection = Field(exclude=True) + + +class RegistryStore(Configurable[RegistryStoreConfig]): + """SQLite persistence for stream name -> config JSON.""" + + default_config: type[RegistryStoreConfig] = RegistryStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._conn: sqlite3.Connection = self.config.conn + self._conn.execute( + "CREATE TABLE IF NOT EXISTS _streams (" + " name TEXT PRIMARY KEY," + " config TEXT NOT NULL" + ")" + ) + self._conn.commit() + + def get(self, name: str) -> dict[str, Any] | None: + row = self._conn.execute("SELECT config FROM _streams WHERE name = ?", (name,)).fetchone() + if row is None: + return None + return json.loads(row[0]) # type: ignore[no-any-return] + + def put(self, name: str, config: dict[str, Any]) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO _streams (name, config) VALUES (?, ?)", + (name, json.dumps(config)), + ) + self._conn.commit() + + def delete(self, name: str) -> None: + self._conn.execute("DELETE FROM _streams WHERE name = ?", (name,)) + self._conn.commit() + + def list_streams(self) -> list[str]: + rows = self._conn.execute("SELECT name FROM _streams").fetchall() + return [r[0] for r in rows] diff --git a/dimos/memory2/store/README.md b/dimos/memory2/store/README.md new file mode 100644 index 0000000000..ff18640c0b --- /dev/null +++ b/dimos/memory2/store/README.md @@ -0,0 +1,130 @@ +# store — Store implementations + +Metadata index backends for memory. Each index implements the `ObservationStore` protocol to provide observation metadata storage with query support. The concrete `Backend` class handles orchestration (blob, vector, live) on top of any index. + +## Existing implementations + +| ObservationStore | File | Status | Storage | +|-----------------|-------------|----------|-------------------------------------| +| `ListObservationStore` | `memory.py` | Complete | In-memory lists, brute-force search | +| `SqliteObservationStore` | `sqlite.py` | Complete | SQLite (WAL, R*Tree, vec0) | + +## Writing a new index + +### 1. Implement the ObservationStore protocol + +```python +from dimos.memory2.observationstore.base import ObservationStore +from dimos.memory2.type.filter import StreamQuery +from dimos.memory2.type.observation import Observation + +class MyObservationStore(Generic[T]): + def __init__(self, name: str) -> None: + self._name = name + + @property + def name(self) -> str: + return self._name + + def insert(self, obs: Observation[T]) -> int: + """Insert observation metadata, return assigned id.""" + row_id = self._next_id + self._next_id += 1 + # ... persist metadata ... + return row_id + + def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + """Yield observations matching the query.""" + # The index handles metadata query fields: + # q.filters — list of Filter objects (each has .matches(obs)) + # q.order_field — sort field name (e.g. "ts") + # q.order_desc — sort direction + # q.limit_val — max results + # q.offset_val — skip first N + # q.search_text — substring text search + ... + + def count(self, q: StreamQuery) -> int: + """Count matching observations.""" + ... + + def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: + """Batch fetch by id (for vector search results).""" + ... +``` + +`ObservationStore` is a `@runtime_checkable` Protocol — no base class needed, just implement the methods. + +### 2. Create a Store subclass + +```python +from dimos.memory2.backend import Backend +from dimos.memory2.codecs.base import codec_for +from dimos.memory2.store.base import Store + +class MyStore(Store): + def _create_backend( + self, name: str, payload_type: type | None = None, **config: Any + ) -> Backend: + index = MyObservationStore(name) + codec = codec_for(payload_type) + return Backend( + index=index, + codec=codec, + blob_store=config.get("blob_store"), + vector_store=config.get("vector_store"), + notifier=config.get("notifier"), + eager_blobs=config.get("eager_blobs", False), + ) + + def list_streams(self) -> list[str]: + return list(self._streams.keys()) + + def delete_stream(self, name: str) -> None: + self._streams.pop(name, None) +``` + +The Store creates a `Backend` composite for each stream. The `Backend` handles all orchestration (encode → insert → store blob → index vector → notify) so your index only needs to handle metadata. + +### 3. Add to the grid test + +In `test_impl.py`, add your store to the fixture so all standard tests run against it: + +```python +@pytest.fixture(params=["memory", "sqlite", "myindex"]) +def store(request, tmp_path): + if request.param == "myindex": + return MyStore(...) + ... +``` + +Use `pytest.mark.xfail` for features not yet implemented — the grid test covers: append, fetch, iterate, count, first/last, exists, all filters, ordering, limit/offset, embeddings, text search. + +### Query contract + +The index must handle the `StreamQuery` metadata fields. Vector search and blob loading are handled by the `Backend` composite — the index never needs to deal with them. + +`StreamQuery.apply(iterator)` provides a complete Python-side execution path — filters, text search, vector search, ordering, offset/limit — all as in-memory operations. ObservationStorees can use it in three ways: + +**Full delegation** — simplest, good enough for in-memory indexes: +```python +def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + return q.apply(iter(self._data)) +``` + +**Partial push-down** — handle some operations natively, delegate the rest: +```python +def query(self, q: StreamQuery) -> Iterator[Observation[T]]: + # Handle filters and ordering in SQL + rows = self._sql_query(q.filters, q.order_field, q.order_desc) + # Delegate remaining operations to Python + remaining = StreamQuery( + search_text=q.search_text, + offset_val=q.offset_val, limit_val=q.limit_val, + ) + return remaining.apply(iter(rows)) +``` + +**Full push-down** — translate everything to native queries (SQL WHERE, FTS5 MATCH) without calling `apply()` at all. + +For filters, each `Filter` object has a `.matches(obs) -> bool` method that indexes can use directly if they don't have a native equivalent. diff --git a/dimos/memory2/store/base.py b/dimos/memory2/store/base.py new file mode 100644 index 0000000000..cf571f23b0 --- /dev/null +++ b/dimos/memory2/store/base.py @@ -0,0 +1,166 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, TypeVar, cast + +from dimos.core.resource import CompositeResource +from dimos.memory2.backend import Backend +from dimos.memory2.blobstore.base import BlobStore +from dimos.memory2.codecs.base import Codec, codec_for, codec_from_id +from dimos.memory2.notifier.base import Notifier +from dimos.memory2.notifier.subject import SubjectNotifier +from dimos.memory2.observationstore.base import ObservationStore +from dimos.memory2.observationstore.memory import ListObservationStore +from dimos.memory2.stream import Stream +from dimos.memory2.vectorstore.base import VectorStore +from dimos.protocol.service.spec import BaseConfig, Configurable + +T = TypeVar("T") + + +class StreamAccessor: + """Attribute-style access: ``store.streams.name`` -> ``store.stream(name)``.""" + + __slots__ = ("_store",) + + def __init__(self, store: Store) -> None: + object.__setattr__(self, "_store", store) + + def __getattr__(self, name: str) -> Stream[Any]: + if name.startswith("_"): + raise AttributeError(name) + store: Store = object.__getattribute__(self, "_store") + if name not in store.list_streams(): + raise AttributeError(f"No stream {name!r}. Available: {store.list_streams()}") + return store.stream(name) + + def __getitem__(self, name: str) -> Stream[Any]: + store: Store = object.__getattribute__(self, "_store") + if name not in store.list_streams(): + raise KeyError(name) + return store.stream(name) + + def __dir__(self) -> list[str]: + store: Store = object.__getattribute__(self, "_store") + return store.list_streams() + + def __repr__(self) -> str: + names = object.__getattribute__(self, "_store").list_streams() + return f"StreamAccessor({names})" + + +class StoreConfig(BaseConfig): + """Store-level config. These are defaults inherited by all streams. + + Component fields accept either a class (instantiated per-stream) or + a live instance (used directly). Classes are the default; instances + are for overrides (e.g. spy stores in tests, shared external stores). + """ + + observation_store: type[ObservationStore] | ObservationStore | None = None # type: ignore[type-arg] + blob_store: type[BlobStore] | BlobStore | None = None + vector_store: type[VectorStore] | VectorStore | None = None + notifier: type[Notifier] | Notifier | None = None # type: ignore[type-arg] + eager_blobs: bool = False + + +class Store(Configurable[StoreConfig], CompositeResource): + """Top-level entry point — wraps a storage location (file, URL, etc.). + + Store directly manages streams. No Session layer. + """ + + default_config: type[StoreConfig] = StoreConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) + self._streams: dict[str, Stream[Any]] = {} + + @property + def streams(self) -> StreamAccessor: + """Attribute-style access to streams: ``store.streams.name``.""" + return StreamAccessor(self) + + @staticmethod + def _resolve_codec( + payload_type: type[Any] | None, raw_codec: Codec[Any] | str | None + ) -> Codec[Any]: + if isinstance(raw_codec, Codec): + return raw_codec + if isinstance(raw_codec, str): + module = ( + f"{payload_type.__module__}.{payload_type.__qualname__}" + if payload_type + else "builtins.object" + ) + return codec_from_id(raw_codec, module) + return codec_for(payload_type) + + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: + """Create a Backend for the named stream. Called once per stream name.""" + codec = self._resolve_codec(payload_type, config.pop("codec", None)) + + # Instantiate or use provided instances + obs = config.pop("observation_store", self.config.observation_store) + if obs is None or isinstance(obs, type): + obs = (obs or ListObservationStore)(name=name) + obs.start() + + bs = config.pop("blob_store", self.config.blob_store) + if isinstance(bs, type): + bs = bs() + bs.start() + + vs = config.pop("vector_store", self.config.vector_store) + if isinstance(vs, type): + vs = vs() + vs.start() + + notifier = config.pop("notifier", self.config.notifier) + if notifier is None or isinstance(notifier, type): + notifier = (notifier or SubjectNotifier)() + + return Backend( + metadata_store=obs, + codec=codec, + blob_store=bs, + vector_store=vs, + notifier=notifier, + eager_blobs=config.get("eager_blobs", False), + ) + + def stream(self, name: str, payload_type: type[T] | None = None, **overrides: Any) -> Stream[T]: + """Get or create a named stream. Returns the same Stream on repeated calls. + + Per-stream ``overrides`` (e.g. ``blob_store=``, ``codec=``) are merged + on top of the store-level defaults from :class:`StoreConfig`. + """ + if name not in self._streams: + resolved = {**self.config.model_dump(exclude_none=True), **overrides} + backend = self._create_backend(name, payload_type, **resolved) + self._streams[name] = Stream(source=backend) + return cast("Stream[T]", self._streams[name]) + + def list_streams(self) -> list[str]: + """Return names of all streams in this store.""" + return list(self._streams.keys()) + + def delete_stream(self, name: str) -> None: + """Delete a stream by name (from cache and underlying storage).""" + self._streams.pop(name, None) diff --git a/dimos/memory2/store/memory.py b/dimos/memory2/store/memory.py new file mode 100644 index 0000000000..6aecde29dd --- /dev/null +++ b/dimos/memory2/store/memory.py @@ -0,0 +1,21 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.memory2.store.base import Store + + +class MemoryStore(Store): + """In-memory store for experimentation.""" + + pass diff --git a/dimos/memory2/store/sqlite.py b/dimos/memory2/store/sqlite.py new file mode 100644 index 0000000000..b655e0a8bc --- /dev/null +++ b/dimos/memory2/store/sqlite.py @@ -0,0 +1,217 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sqlite3 +from typing import Any + +from dimos.memory2.backend import Backend +from dimos.memory2.blobstore.base import BlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore +from dimos.memory2.codecs.base import codec_id +from dimos.memory2.observationstore.sqlite import SqliteObservationStore +from dimos.memory2.registry import RegistryStore, deserialize_component, qual +from dimos.memory2.store.base import Store, StoreConfig +from dimos.memory2.utils.sqlite import open_disposable_sqlite_connection +from dimos.memory2.utils.validation import validate_identifier +from dimos.memory2.vectorstore.base import VectorStore +from dimos.memory2.vectorstore.sqlite import SqliteVectorStore + + +class SqliteStoreConfig(StoreConfig): + """Config for SQLite-backed store.""" + + path: str = "memory.db" + page_size: int = 256 + + +class SqliteStore(Store): + """Store backed by a SQLite database file.""" + + default_config = SqliteStoreConfig + config: SqliteStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._registry_conn = self._open_connection() + self._registry = RegistryStore(conn=self._registry_conn) + + def _open_connection(self) -> sqlite3.Connection: + """Open a new WAL-mode connection with sqlite-vec loaded.""" + disposable, connection = open_disposable_sqlite_connection(self.config.path) + self.register_disposables(disposable) + return connection + + def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: + """Reconstruct a Backend from a stored config dict.""" + from dimos.memory2.codecs.base import codec_from_id + + payload_module = stored["payload_module"] + codec = codec_from_id(stored["codec_id"], payload_module) + eager_blobs = stored.get("eager_blobs", False) + page_size = stored.get("page_size", self.config.page_size) + + backend_conn = self._open_connection() + + # Reconstruct components from serialized config + bs_data = stored.get("blob_store") + if bs_data is not None: + bs_cfg = bs_data.get("config", {}) + if bs_cfg.get("path") is None and bs_data["class"] == qual(SqliteBlobStore): + bs: Any = SqliteBlobStore(conn=backend_conn) + else: + bs = deserialize_component(bs_data) + else: + bs = SqliteBlobStore(conn=backend_conn) + bs.start() + + vs_data = stored.get("vector_store") + if vs_data is not None: + vs_cfg = vs_data.get("config", {}) + if vs_cfg.get("path") is None and vs_data["class"] == qual(SqliteVectorStore): + vs: Any = SqliteVectorStore(conn=backend_conn) + else: + vs = deserialize_component(vs_data) + else: + vs = SqliteVectorStore(conn=backend_conn) + vs.start() + + notifier_data = stored.get("notifier") + if notifier_data is not None: + notifier = deserialize_component(notifier_data) + else: + from dimos.memory2.notifier.subject import SubjectNotifier + + notifier = SubjectNotifier() + + blob_store_conn_match = isinstance(bs, SqliteBlobStore) and bs._conn is backend_conn + + metadata_store: SqliteObservationStore[Any] = SqliteObservationStore( + conn=backend_conn, + name=name, + codec=codec, + blob_store_conn_match=blob_store_conn_match and eager_blobs, + page_size=page_size, + ) + metadata_store.start() + + backend: Backend[Any] = Backend( + metadata_store=metadata_store, + codec=codec, + blob_store=bs, + vector_store=vs, + notifier=notifier, + eager_blobs=eager_blobs, + ) + return backend + + @staticmethod + def _serialize_backend( + backend: Backend[Any], payload_module: str, page_size: int + ) -> dict[str, Any]: + """Serialize a backend's config for registry storage.""" + cfg: dict[str, Any] = { + "payload_module": payload_module, + "codec_id": codec_id(backend.codec), + "eager_blobs": backend.eager_blobs, + "page_size": page_size, + } + if backend.blob_store is not None: + cfg["blob_store"] = backend.blob_store.serialize() + if backend.vector_store is not None: + cfg["vector_store"] = backend.vector_store.serialize() + cfg["notifier"] = backend.notifier.serialize() + return cfg + + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: + validate_identifier(name) + + stored = self._registry.get(name) + + if stored is not None: + # Load path: validate type, assemble from stored config + if payload_type is not None: + actual_module = f"{payload_type.__module__}.{payload_type.__qualname__}" + if actual_module != stored["payload_module"]: + raise ValueError( + f"Stream {name!r} was created with type {stored['payload_module']}, " + f"but opened with {actual_module}" + ) + return self._assemble_backend(name, stored) + + # Create path: inject conn-shared defaults, then delegate to base + if payload_type is None: + raise TypeError(f"Stream {name!r} does not exist yet — payload_type is required") + + backend_conn = self._open_connection() + + # Inject conn-shared instances unless user provided overrides + if not isinstance(config.get("blob_store"), BlobStore): + bs = SqliteBlobStore(conn=backend_conn) + bs.start() + config["blob_store"] = bs + if not isinstance(config.get("vector_store"), VectorStore): + vs = SqliteVectorStore(conn=backend_conn) + vs.start() + config["vector_store"] = vs + + # Resolve codec early — needed for SqliteObservationStore + codec = self._resolve_codec(payload_type, config.get("codec")) + config["codec"] = codec + + # Create SqliteObservationStore with conn-sharing + bs = config["blob_store"] + blob_conn_match = isinstance(bs, SqliteBlobStore) and bs._conn is backend_conn + eager_blobs = config.get("eager_blobs", False) + obs_store: SqliteObservationStore[Any] = SqliteObservationStore( + conn=backend_conn, + name=name, + codec=codec, + blob_store_conn_match=blob_conn_match and eager_blobs, + page_size=config.pop("page_size", self.config.page_size), + ) + obs_store.start() + config["observation_store"] = obs_store + + backend = super()._create_backend(name, payload_type, **config) + + # Persist to registry + payload_module = f"{payload_type.__module__}.{payload_type.__qualname__}" + self._registry.put( + name, + self._serialize_backend( + backend, payload_module, config["observation_store"].config.page_size + ), + ) + + return backend + + def list_streams(self) -> list[str]: + db_names = set(self._registry.list_streams()) + return sorted(db_names | set(self._streams.keys())) + + def delete_stream(self, name: str) -> None: + super().delete_stream(name) + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}"') + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}_blob"') + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}_vec"') + self._registry_conn.execute(f'DROP TABLE IF EXISTS "{name}_rtree"') + self._registry.delete(name) + + def stop(self) -> None: + super().stop() + self._registry_conn.close() diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py new file mode 100644 index 0000000000..545d387c32 --- /dev/null +++ b/dimos/memory2/stream.py @@ -0,0 +1,363 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.core.resource import Resource +from dimos.memory2.buffer import BackpressureBuffer, KeepLast +from dimos.memory2.transform import FnIterTransformer, FnTransformer, Transformer +from dimos.memory2.type.filter import ( + AfterFilter, + AtFilter, + BeforeFilter, + Filter, + NearFilter, + PredicateFilter, + StreamQuery, + TagsFilter, + TimeRangeFilter, +) +from dimos.memory2.type.observation import EmbeddedObservation, Observation + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + import reactivex + from reactivex.abc import DisposableBase, ObserverBase + + from dimos.memory2.backend import Backend + from dimos.models.embedding.base import Embedding + +T = TypeVar("T") +R = TypeVar("R") + + +class Stream(Resource, Generic[T]): + """Lazy, pull-based stream over observations. + + Every filter/transform method returns a new Stream — no computation + happens until iteration. Backends handle query application for stored + data; transform sources apply filters as Python predicates. + + Implements Resource so live streams can be cleanly stopped via + ``stop()`` or used as a context manager. + """ + + def __init__( + self, + source: Backend[T] | Stream[Any], + *, + xf: Transformer[Any, T] | None = None, + query: StreamQuery = StreamQuery(), + ) -> None: + self._source = source + self._xf = xf + self._query = query + + def start(self) -> None: + pass + + def stop(self) -> None: + """Close the live buffer (if any), unblocking iteration.""" + buf = self._query.live_buffer + if buf is not None: + buf.close() + if isinstance(self._source, Stream): + self._source.stop() + + def __str__(self) -> str: + # Walk the source chain to collect (xf, query) pairs + chain: list[tuple[Any, StreamQuery]] = [] + current: Any = self + while isinstance(current, Stream): + chain.append((current._xf, current._query)) + current = current._source + chain.reverse() # innermost first + + # current is the Backend + name = getattr(current, "name", "?") + result = f'Stream("{name}")' + + for xf, query in chain: + if xf is not None: + result += f" -> {xf}" + q_str = str(query) + if q_str: + result += f" | {q_str}" + + return result + + def is_live(self) -> bool: + """True if this stream (or any ancestor in the chain) is in live mode.""" + if self._query.live_buffer is not None: + return True + if isinstance(self._source, Stream): + return self._source.is_live() + return False + + def __iter__(self) -> Iterator[Observation[T]]: + return self._build_iter() + + def _build_iter(self) -> Iterator[Observation[T]]: + if isinstance(self._source, Stream): + return self._iter_transform() + # Backend handles all query application (including live if requested) + return self._source.iterate(self._query) + + def _iter_transform(self) -> Iterator[Observation[T]]: + """Iterate a transform source, applying query filters in Python.""" + assert isinstance(self._source, Stream) and self._xf is not None + it: Iterator[Observation[T]] = self._xf(iter(self._source)) + return self._query.apply(it, live=self.is_live()) + + def _replace_query(self, **overrides: Any) -> Stream[T]: + q = self._query + new_q = StreamQuery( + filters=overrides.get("filters", q.filters), + order_field=overrides.get("order_field", q.order_field), + order_desc=overrides.get("order_desc", q.order_desc), + limit_val=overrides.get("limit_val", q.limit_val), + offset_val=overrides.get("offset_val", q.offset_val), + live_buffer=overrides.get("live_buffer", q.live_buffer), + search_vec=overrides.get("search_vec", q.search_vec), + search_k=overrides.get("search_k", q.search_k), + search_text=overrides.get("search_text", q.search_text), + ) + return Stream(self._source, xf=self._xf, query=new_q) + + def _with_filter(self, f: Filter) -> Stream[T]: + return self._replace_query(filters=(*self._query.filters, f)) + + def after(self, t: float) -> Stream[T]: + return self._with_filter(AfterFilter(t)) + + def before(self, t: float) -> Stream[T]: + return self._with_filter(BeforeFilter(t)) + + def time_range(self, t1: float, t2: float) -> Stream[T]: + return self._with_filter(TimeRangeFilter(t1, t2)) + + def at(self, t: float, tolerance: float = 1.0) -> Stream[T]: + return self._with_filter(AtFilter(t, tolerance)) + + def near(self, pose: Any, radius: float) -> Stream[T]: + return self._with_filter(NearFilter(pose, radius)) + + def tags(self, **tags: Any) -> Stream[T]: + return self._with_filter(TagsFilter(tags)) + + def order_by(self, field: str, desc: bool = False) -> Stream[T]: + return self._replace_query(order_field=field, order_desc=desc) + + def limit(self, k: int) -> Stream[T]: + return self._replace_query(limit_val=k) + + def offset(self, n: int) -> Stream[T]: + return self._replace_query(offset_val=n) + + def search(self, query: Embedding, k: int) -> Stream[T]: + """Return top-k observations by cosine similarity to *query*. + + The backend handles the actual computation. ListObservationStore does + brute-force cosine; SqliteObservationStore pushes down to vec0. + """ + return self._replace_query(search_vec=query, search_k=k) + + def search_text(self, text: str) -> Stream[T]: + """Filter observations whose data contains *text*. + + ListObservationStore does case-insensitive substring match; + SqliteObservationStore (future) pushes down to FTS5. + """ + return self._replace_query(search_text=text) + + def filter(self, pred: Callable[[Observation[T]], bool]) -> Stream[T]: + """Filter by arbitrary predicate on the full Observation.""" + return self._with_filter(PredicateFilter(pred)) + + def map(self, fn: Callable[[Observation[T]], Observation[R]]) -> Stream[R]: + """Transform each observation's data via callable.""" + return self.transform(FnTransformer(lambda obs: fn(obs))) + + def transform( + self, + xf: Transformer[T, R] | Callable[[Iterator[Observation[T]]], Iterator[Observation[R]]], + ) -> Stream[R]: + """Wrap this stream with a transformer. Returns a new lazy Stream. + + Accepts a ``Transformer`` subclass or a bare callable / generator + function with the same ``Iterator[Obs] → Iterator[Obs]`` signature:: + + def detect(upstream): + for obs in upstream: + yield obs.derive(data=run_detector(obs.data)) + + images.transform(detect).save(detections) + """ + if not isinstance(xf, Transformer): + xf = FnIterTransformer(xf) + return Stream(source=self, xf=xf, query=StreamQuery()) + + def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: + """Return a stream whose iteration never ends — backfill then live tail. + + All backends support live mode via their built-in ``Notifier``. + Call .live() before .transform(), not after. + + Default buffer: KeepLast(). The backend handles subscription, dedup, + and backpressure — how it does so is its business. + """ + if isinstance(self._source, Stream): + raise TypeError( + "Cannot call .live() on a transform stream. " + "Call .live() on the source stream, then .transform()." + ) + buf = buffer if buffer is not None else KeepLast() + return self._replace_query(live_buffer=buf) + + def save(self, target: Stream[T]) -> Stream[T]: + """Sync terminal: iterate self, append each obs to target's backend. + + Returns the target stream for continued querying. + """ + if isinstance(target._source, Stream): + raise TypeError("Cannot save to a transform stream. Target must be backend-backed.") + backend = target._source + for obs in self: + backend.append(obs) + return target + + def fetch(self) -> list[Observation[T]]: + """Materialize all observations into a list.""" + if self.is_live(): + raise TypeError( + ".fetch() on a live stream would block forever. " + "Use .drain() or .save(target) instead." + ) + return list(self) + + def first(self) -> Observation[T]: + """Return the first matching observation.""" + it = iter(self.limit(1)) + try: + return next(it) + except StopIteration: + raise LookupError("No matching observation") from None + + def last(self) -> Observation[T]: + """Return the last matching observation (by timestamp).""" + return self.order_by("ts", desc=True).first() + + def count(self) -> int: + """Count matching observations.""" + if not isinstance(self._source, Stream): + return self._source.count(self._query) + if self.is_live(): + raise TypeError(".count() on a live transform stream would block forever.") + return sum(1 for _ in self) + + def exists(self) -> bool: + """Check if any matching observation exists.""" + return next(iter(self.limit(1)), None) is not None + + def get_time_range(self) -> tuple[float, float]: + """Return (min_ts, max_ts) for matching observations.""" + first = self.first() + last = self.last() + return (first.ts, last.ts) + + def summary(self) -> str: + """Return a short human-readable summary: count, time range, duration.""" + from datetime import datetime, timezone + + n = self.count() + if n == 0: + return f"{self}: empty" + + (t0, t1) = self.get_time_range() + + fmt = "%Y-%m-%d %H:%M:%S" + dt0 = datetime.fromtimestamp(t0, tz=timezone.utc).strftime(fmt) + dt1 = datetime.fromtimestamp(t1, tz=timezone.utc).strftime(fmt) + dur = t1 - t0 + return f"{self}: {n} items, {dt0} — {dt1} ({dur:.1f}s)" + + def drain(self) -> int: + """Consume all observations, discarding results. Returns count consumed. + + Use for side-effect pipelines (e.g. live embed-and-store) where you + don't need to collect results in memory. + """ + n = 0 + for _ in self: + n += 1 + return n + + def observable(self) -> reactivex.Observable[Observation[T]]: + """Convert this stream to an RxPY Observable. + + Iteration is scheduled on the dimos thread pool so subscribe() never + blocks the calling thread. + """ + import reactivex + import reactivex.operators as ops + + from dimos.utils.threadpool import get_scheduler + + return reactivex.from_iterable(self).pipe( + ops.subscribe_on(get_scheduler()), + ) + + def subscribe( + self, + on_next: Callable[[Observation[T]], None] | ObserverBase[Observation[T]] | None = None, + on_error: Callable[[Exception], None] | None = None, + on_completed: Callable[[], None] | None = None, + ) -> DisposableBase: + """Subscribe to this stream as an RxPY Observable.""" + return self.observable().subscribe( # type: ignore[call-overload] + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + ) + + def append( + self, + payload: T, + *, + ts: float | None = None, + pose: Any | None = None, + tags: dict[str, Any] | None = None, + embedding: Embedding | None = None, + ) -> Observation[T]: + """Append to the backing store. Only works if source is a Backend.""" + if isinstance(self._source, Stream): + raise TypeError("Cannot append to a transform stream. Append to the source stream.") + _ts = ts if ts is not None else time.time() + _tags = tags or {} + if embedding is not None: + obs: Observation[T] = EmbeddedObservation( + id=-1, + ts=_ts, + pose=pose, + tags=_tags, + _data=payload, + embedding=embedding, + ) + else: + obs = Observation(id=-1, ts=_ts, pose=pose, tags=_tags, _data=payload) + return self._source.append(obs) diff --git a/dimos/memory2/streaming.md b/dimos/memory2/streaming.md new file mode 100644 index 0000000000..fd7f5519a1 --- /dev/null +++ b/dimos/memory2/streaming.md @@ -0,0 +1,109 @@ +# Stream evaluation model + +Stream methods fall into three categories: **lazy**, **materializing**, and **terminal**. The distinction matters for live (infinite) streams. + +`is_live()` walks the source chain to detect live mode — any stream whose ancestor called `.live()` returns `True`. +All materializing operations and unsafe terminals check this and raise `TypeError` immediately rather than silently hanging. + +## Lazy (streaming) + +These return generators — each observation flows through one at a time. Safe with live/infinite streams. No internal buffering between stages. + +| Method | How | +|---------------------------------------------------------------------------|-------------------------------------------------| +| `.after()` `.before()` `.time_range()` `.at()` `.near()` `.filter_tags()` | Filter predicates — skip non-matching obs | +| `.filter(pred)` | Same, user-defined predicate | +| `.transform(xf_or_fn)` / `.map(fn)` | Generator — yields transformed obs one by one | +| `.search_text(text)` | Generator — substring match filter | +| `.limit(k)` | `islice` — stops after k | +| `.offset(n)` | `islice` — skips first n | +| `.live()` | Enables live tail (backfill then block for new) | + +These compose freely. A chain like `.after(t).filter(pred).transform(xf).limit(10)` pulls lazily — the source only produces what the consumer asks for. + +## Materializing (collect-then-process) + +These **must consume the entire upstream** before producing output. On a live stream, they raise `TypeError` immediately. + +| Method | Why | Live behaviour | +|--------------------|----------------------------------------------|----------------| +| `.search(vec, k)` | Cosine-ranks all observations, returns top-k | TypeError | +| `.order_by(field)` | `sorted(list(it))` — needs all items to sort | TypeError | + +On a backend-backed stream (not a transform), both are pushed down to the backend which handles them on its own data structure (snapshot). The guard only fires when these appear on a **transform stream** whose upstream is live — detected via `is_live()`. + +### Rejected patterns (raise TypeError) + +```python +# TypeError: search requires finite data +stream.live().transform(Embed(model)).search(vec, k=5) + +# TypeError: order_by requires finite data +stream.live().transform(xf).order_by("ts", desc=True) + +# TypeError (via order_by): last() calls order_by internally +stream.live().transform(xf).last() +``` + +### Safe equivalents + +```python +# Search the stored data, not the live tail +results = stream.search(vec, k=5).fetch() + +# First works fine (uses limit(1), no materialization) +obs = stream.live().transform(xf).first() +``` + +## Terminal (consume the iterator) + +Terminals trigger iteration and return a value. They're the "go" button — nothing executes until a terminal is called. + +| Method | Returns | Memory | Live behaviour | +|-----------------|---------------------|--------------------|-----------------------------------------| +| `.fetch()` | `list[Observation]` | Grows with results | TypeError without `.limit()` first | +| `.drain()` | `int` (count) | Constant | Blocks forever, memory stays flat | +| `.save(target)` | target `Stream` | Constant | Blocks forever, appends each to store | +| `.first()` | `Observation` | Constant | Returns first item, then stops | +| `.exists()` | `bool` | Constant | Returns after one item check | +| `.last()` | `Observation` | Materializes | TypeError (uses order_by internally) | +| `.count()` | `int` | Constant | TypeError on transform streams | + +### Choosing the right terminal + +**Batch query** — collect results into memory: +```python +results = stream.after(t).search(vec, k=10).fetch() +``` + +**Live ingestion** — process forever, constant memory: +```python +# Embed and store continuously +stream.live().transform(EmbedImages(clip)).save(target) + +# Side-effect pipeline (no storage) +stream.live().transform(process).drain() +``` + +**One-shot** — get a single observation: +```python +obs = stream.live().transform(xf).first() # blocks until one arrives +has_data = stream.exists() # quick check +``` + +**Bounded live** — collect a fixed number from a live stream: +```python +batch = stream.live().limit(100).fetch() # OK — limit makes it finite +``` + +### Error summary + +All operations that would silently hang on live streams raise `TypeError` instead: + +| Pattern | Error | +|-------------------------------------|-----------------------------------------------| +| `live.transform(xf).search(vec, k)` | `.search() requires finite data` | +| `live.transform(xf).order_by("ts")` | `.order_by() requires finite data` | +| `live.fetch()` (without `.limit()`) | `.fetch() would collect forever` | +| `live.transform(xf).count()` | `.count() would block forever` | +| `live.transform(xf).last()` | `.order_by() requires finite data` (via last) | diff --git a/dimos/memory2/test_blobstore_integration.py b/dimos/memory2/test_blobstore_integration.py new file mode 100644 index 0000000000..6c26a635c0 --- /dev/null +++ b/dimos/memory2/test_blobstore_integration.py @@ -0,0 +1,161 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for BlobStore integration with MemoryStore/Backend.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.store.memory import MemoryStore +from dimos.memory2.type.observation import _UNLOADED +from dimos.models.embedding.base import Embedding + +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path + + +def _emb(vec: list[float]) -> Embedding: + v = np.array(vec, dtype=np.float32) + v /= np.linalg.norm(v) + 1e-10 + return Embedding(vector=v) + + +@pytest.fixture +def bs(tmp_path: Path) -> Iterator[FileBlobStore]: + blob_store = FileBlobStore(root=str(tmp_path / "blobs")) + blob_store.start() + yield blob_store + blob_store.stop() + + +@pytest.fixture +def store(bs: FileBlobStore) -> Iterator[MemoryStore]: + with MemoryStore(blob_store=bs) as s: + yield s + + +class TestBlobStoreIntegration: + def test_append_stores_in_blobstore(self, bs: FileBlobStore, store: MemoryStore) -> None: + s = store.stream("data", bytes) + s.append(b"hello", ts=1.0) + + # Blob was written to the file store + raw = bs.get("data", 0) + assert len(raw) > 0 + + def test_lazy_data_not_loaded_until_access(self, store: MemoryStore) -> None: + s = store.stream("data", str) + obs = s.append("payload", ts=1.0) + + # Data replaced with sentinel after append + assert isinstance(obs._data, type(_UNLOADED)) + assert obs._loader is not None + + def test_lazy_data_loads_correctly(self, store: MemoryStore) -> None: + s = store.stream("data", str) + s.append("payload", ts=1.0) + + result = s.first() + assert result.data == "payload" + + def test_eager_preloads_data(self, bs: FileBlobStore) -> None: + with MemoryStore(blob_store=bs, eager_blobs=True) as store: + s = store.stream("data", str) + s.append("payload", ts=1.0) + + # Iterating with eager_blobs triggers load + results = s.fetch() + assert len(results) == 1 + # Data should be loaded (not _UNLOADED) + assert not isinstance(results[0]._data, type(_UNLOADED)) + assert results[0].data == "payload" + + def test_per_stream_eager_override(self, store: MemoryStore) -> None: + # Default: lazy + lazy_stream = store.stream("lazy", str) + lazy_stream.append("lazy-val", ts=1.0) + + # Override: eager + eager_stream = store.stream("eager", str, eager_blobs=True) + eager_stream.append("eager-val", ts=1.0) + + lazy_results = lazy_stream.fetch() + eager_results = eager_stream.fetch() + + # Lazy: data stays unloaded until accessed + assert lazy_results[0].data == "lazy-val" + + # Eager: data pre-loaded during iteration + assert not isinstance(eager_results[0]._data, type(_UNLOADED)) + assert eager_results[0].data == "eager-val" + + def test_no_blobstore_unchanged(self) -> None: + with MemoryStore() as store: + s = store.stream("data", str) + obs = s.append("inline", ts=1.0) + + # Without blob store, data stays inline + assert obs._data == "inline" + assert obs._loader is None + assert obs.data == "inline" + + def test_blobstore_with_vector_search(self, bs: FileBlobStore) -> None: + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs = MemoryVectorStore() + with MemoryStore(blob_store=bs, vector_store=vs) as store: + s = store.stream("vecs", str) + s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) + s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) + s.append("south", ts=3.0, embedding=_emb([0, -1, 0])) + + # Vector search triggers lazy load via obs.derive(data=obs.data, ...) + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity > 0.99 + + def test_blobstore_with_text_search(self, store: MemoryStore) -> None: + s = store.stream("logs", str) + s.append("motor fault", ts=1.0) + s.append("temperature ok", ts=2.0) + + # Text search triggers lazy load via str(obs.data) + results = s.search_text("motor").fetch() + assert len(results) == 1 + assert results[0].data == "motor fault" + + def test_multiple_appends_get_unique_blobs(self, store: MemoryStore) -> None: + s = store.stream("multi", str) + s.append("first", ts=1.0) + s.append("second", ts=2.0) + s.append("third", ts=3.0) + + results = s.fetch() + assert [r.data for r in results] == ["first", "second", "third"] + + def test_fetch_preserves_metadata(self, store: MemoryStore) -> None: + s = store.stream("meta", str) + s.append("val", ts=42.0, tags={"kind": "info"}) + + result = s.first() + assert result.ts == 42.0 + assert result.tags == {"kind": "info"} + assert result.data == "val" diff --git a/dimos/memory2/test_buffer.py b/dimos/memory2/test_buffer.py new file mode 100644 index 0000000000..f851a6fcee --- /dev/null +++ b/dimos/memory2/test_buffer.py @@ -0,0 +1,86 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for backpressure buffers.""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from dimos.memory2.buffer import Bounded, ClosedError, DropNew, KeepLast, Unbounded + + +class TestBackpressureBuffers: + """Thread-safe buffers bridging push sources to pull consumers.""" + + def test_keep_last_overwrites(self): + buf = KeepLast[int]() + buf.put(1) + buf.put(2) + buf.put(3) + assert buf.take() == 3 + assert len(buf) == 0 + + def test_bounded_drops_oldest(self): + buf = Bounded[int](maxlen=2) + buf.put(1) + buf.put(2) + buf.put(3) # drops 1 + assert buf.take() == 2 + assert buf.take() == 3 + + def test_drop_new_rejects(self): + buf = DropNew[int](maxlen=2) + assert buf.put(1) is True + assert buf.put(2) is True + assert buf.put(3) is False # rejected + assert buf.take() == 1 + assert buf.take() == 2 + + def test_unbounded_keeps_all(self): + buf = Unbounded[int]() + for i in range(100): + buf.put(i) + assert len(buf) == 100 + + def test_close_signals_end(self): + buf = KeepLast[int]() + buf.close() + with pytest.raises(ClosedError): + buf.take() + + def test_buffer_is_iterable(self): + """Iterating a buffer yields items until closed.""" + buf = Unbounded[int]() + buf.put(1) + buf.put(2) + buf.close() + assert list(buf) == [1, 2] + + def test_take_blocks_until_put(self): + buf = KeepLast[int]() + result = [] + + def producer(): + time.sleep(0.05) + buf.put(42) + + t = threading.Thread(target=producer) + t.start() + result.append(buf.take(timeout=2.0)) + t.join() + assert result == [42] diff --git a/dimos/memory2/test_e2e.py b/dimos/memory2/test_e2e.py new file mode 100644 index 0000000000..5b1f0af767 --- /dev/null +++ b/dimos/memory2/test_e2e.py @@ -0,0 +1,256 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""E2E test: import legacy pickle replays into memory2 SqliteStore.""" + +from __future__ import annotations + +import bisect +from typing import TYPE_CHECKING, Any + +import pytest + +from dimos.memory2.store.sqlite import SqliteStore +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.data import get_data_dir +from dimos.utils.testing.replay import TimedSensorReplay + +if TYPE_CHECKING: + from collections.abc import Iterator + +DB_PATH = get_data_dir() / "go2_bigoffice.db" + + +@pytest.fixture(scope="module") +def session() -> Iterator[SqliteStore]: + store = SqliteStore(path=str(DB_PATH)) + with store: + yield store + store.stop() + + +class PoseIndex: + """Preloaded odom data with O(log n) closest-timestamp lookup.""" + + def __init__(self, replay: TimedSensorReplay) -> None: # type: ignore[type-arg] + self._timestamps: list[float] = [] + self._data: list[Any] = [] + for ts, data in replay.iterate_ts(): + self._timestamps.append(ts) + self._data.append(data) + + def find_closest(self, ts: float) -> Any | None: + if not self._timestamps: + return None + idx = bisect.bisect_left(self._timestamps, ts) + # Compare the two candidates around the insertion point + if idx == 0: + return self._data[0] + if idx >= len(self._timestamps): + return self._data[-1] + if ts - self._timestamps[idx - 1] <= self._timestamps[idx] - ts: + return self._data[idx - 1] + return self._data[idx] + + +@pytest.fixture(scope="module") +def video_replay() -> TimedSensorReplay: + return TimedSensorReplay("unitree_go2_bigoffice/video") + + +@pytest.fixture(scope="module") +def odom_index() -> PoseIndex: + return PoseIndex(TimedSensorReplay("unitree_go2_bigoffice/odom")) + + +@pytest.fixture(scope="module") +def lidar_replay() -> TimedSensorReplay: + return TimedSensorReplay("unitree_go2_bigoffice/lidar") + + +@pytest.mark.tool +class TestImportReplay: + """Import legacy pickle replay data into a memory2 SqliteStore.""" + + def test_import_video( + self, + session: SqliteStore, + video_replay: TimedSensorReplay, # type: ignore[type-arg] + odom_index: PoseIndex, + ) -> None: + with session.stream("color_image", Image) as video: + count = 0 + for ts, frame in video_replay.iterate_ts(): + pose = odom_index.find_closest(ts) + print("import", frame) + video.append(frame, ts=ts, pose=pose) + count += 1 + + assert count > 0 + assert video.count() == count + print(f"Imported {count} video frames") + + def test_import_lidar( + self, + session: SqliteStore, + lidar_replay: TimedSensorReplay, # type: ignore[type-arg] + odom_index: PoseIndex, + ) -> None: + # can also be explicit here + # lidar = session.stream("lidar", PointCloud2, codec=Lz4Codec(LcmCodec(PointCloud2))) + lidar = session.stream("lidar", PointCloud2, codec="lz4+lcm") + + count = 0 + for ts, frame in lidar_replay.iterate_ts(): + pose = odom_index.find_closest(ts) + print("import", frame) + lidar.append(frame, ts=ts, pose=pose) + count += 1 + + assert count > 0 + assert lidar.count() == count + print(f"Imported {count} lidar frames") + + def test_query_imported_data(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + lidar = session.stream("lidar", PointCloud2) + + assert video.exists() + assert lidar.exists() + + first_frame = video.first() + last_frame = video.last() + assert first_frame.ts < last_frame.ts + + mid_ts = (first_frame.ts + last_frame.ts) / 2 + subset = video.time_range(first_frame.ts, mid_ts).fetch() + assert 0 < len(subset) < video.count() + + streams = session.list_streams() + assert "color_image" in streams + assert "lidar" in streams + + +@pytest.mark.tool +class TestE2EQuery: + """Query operations against real robot replay data.""" + + def test_list_streams(self, session: SqliteStore) -> None: + streams = session.list_streams() + print(streams) + + assert "color_image" in streams + assert "lidar" in streams + assert session.streams.color_image + assert session.streams.lidar + + print(session.streams.lidar) + + def test_video_count(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + assert video.count() > 1000 + + def test_lidar_count(self, session: SqliteStore) -> None: + lidar = session.stream("lidar", PointCloud2) + assert lidar.count() > 1000 + + def test_first_last_timestamps(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + first = video.first() + last = video.last() + assert first.ts < last.ts + duration = last.ts - first.ts + assert duration > 10.0 # at least 10s of data + + def test_time_range_filter(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + first = video.first() + + # Grab first 5 seconds + window = video.time_range(first.ts, first.ts + 5.0).fetch() + assert len(window) > 0 + assert len(window) < video.count() + assert all(first.ts <= obs.ts <= first.ts + 5.0 for obs in window) + + def test_limit_offset_pagination(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + page1 = video.limit(10).fetch() + page2 = video.offset(10).limit(10).fetch() + + assert len(page1) == 10 + assert len(page2) == 10 + assert page1[-1].ts < page2[0].ts # no overlap + + def test_order_by_desc(self, session: SqliteStore) -> None: + video = session.stream("color_image", Image) + last_10 = video.order_by("ts", desc=True).limit(10).fetch() + + assert len(last_10) == 10 + assert all(last_10[i].ts >= last_10[i + 1].ts for i in range(9)) + + def test_lazy_data_loads_correctly(self, session: SqliteStore) -> None: + """Verify lazy blob loading returns valid Image data.""" + from dimos.memory2.type.observation import _Unloaded + + video = session.stream("color_image", Image) + obs = next(iter(video.limit(1))) + + # Should start lazy + assert isinstance(obs._data, _Unloaded) + + # Trigger load + frame = obs.data + assert isinstance(frame, Image) + assert frame.width > 0 + assert frame.height > 0 + + def test_iterate_window_decodes_all(self, session: SqliteStore) -> None: + """Iterate a time window and verify every frame decodes.""" + video = session.stream("color_image", Image) + first_ts = video.first().ts + + window = video.time_range(first_ts, first_ts + 2.0) + count = 0 + for obs in window: + frame = obs.data + assert isinstance(frame, Image) + count += 1 + assert count > 0 + + def test_lidar_data_loads(self, session: SqliteStore) -> None: + """Verify lidar blobs decode to PointCloud2.""" + lidar = session.stream("lidar", PointCloud2) + frame = lidar.first().data + assert isinstance(frame, PointCloud2) + + def test_poses_present(self, session: SqliteStore) -> None: + """Verify poses were stored during import.""" + video = session.stream("color_image", Image) + obs = video.first() + assert obs.pose is not None + + def test_cross_stream_time_alignment(self, session: SqliteStore) -> None: + """Video and lidar should overlap in time.""" + video = session.stream("color_image", Image) + lidar = session.stream("lidar", PointCloud2) + + v_first, v_last = video.first().ts, video.last().ts + l_first, l_last = lidar.first().ts, lidar.last().ts + + # Overlap: max of starts < min of ends + overlap_start = max(v_first, l_first) + overlap_end = min(v_last, l_last) + assert overlap_start < overlap_end, "Video and lidar should overlap in time" + assert overlap_start < overlap_end, "Video and lidar should overlap in time" diff --git a/dimos/memory2/test_e2e_processing.py b/dimos/memory2/test_e2e_processing.py new file mode 100644 index 0000000000..81eba5c2a8 --- /dev/null +++ b/dimos/memory2/test_e2e_processing.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + + +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dimos/memory2/test_embedding.py b/dimos/memory2/test_embedding.py new file mode 100644 index 0000000000..57d66da278 --- /dev/null +++ b/dimos/memory2/test_embedding.py @@ -0,0 +1,396 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for embedding layer: EmbeddedObservation, vector search, text search, transformers.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from dimos.memory2.type.observation import EmbeddedObservation, Observation +from dimos.models.embedding.base import Embedding + + +def _emb(vec: list[float]) -> Embedding: + """Return a unit-normalized Embedding.""" + v = np.array(vec, dtype=np.float32) + v /= np.linalg.norm(v) + 1e-10 + return Embedding(vector=v) + + +class TestEmbeddedObservation: + def test_construction(self) -> None: + emb = _emb([1, 0, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="hello", embedding=emb) + assert obs.data == "hello" + assert obs.embedding is emb + assert obs.similarity is None + + def test_is_observation(self) -> None: + obs = EmbeddedObservation(id=0, ts=1.0, _data="x", embedding=_emb([1, 0])) + assert isinstance(obs, Observation) + + def test_derive_preserves_embedding(self) -> None: + emb = _emb([1, 0, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=emb) + derived = obs.derive(data="b") + assert isinstance(derived, EmbeddedObservation) + assert derived.embedding is emb + assert derived.data == "b" + + def test_derive_replaces_embedding(self) -> None: + old = _emb([1, 0, 0]) + new = _emb([0, 1, 0]) + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=old) + derived = obs.derive(data="a", embedding=new) + assert derived.embedding is new + + def test_derive_preserves_similarity(self) -> None: + obs = EmbeddedObservation(id=0, ts=1.0, _data="a", embedding=_emb([1, 0]), similarity=0.95) + derived = obs.derive(data="b") + assert derived.similarity == 0.95 + + def test_observation_derive_promotes_to_embedded(self) -> None: + obs = Observation(id=0, ts=1.0, _data="plain") + emb = _emb([1, 0, 0]) + derived = obs.derive(data="plain", embedding=emb) + assert isinstance(derived, EmbeddedObservation) + assert derived.embedding is emb + + def test_observation_derive_without_embedding_stays_observation(self) -> None: + obs = Observation(id=0, ts=1.0, _data="plain") + derived = obs.derive(data="still plain") + assert type(derived) is Observation + + +class TestListBackendEmbedding: + def test_append_with_embedding(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + emb = _emb([1, 0, 0]) + obs = s.append("hello", embedding=emb) + assert isinstance(obs, EmbeddedObservation) + assert obs.embedding is emb + + def test_append_without_embedding(self, memory_store) -> None: + s = memory_store.stream("plain", str) + obs = s.append("hello") + assert type(obs) is Observation + + def test_search_returns_top_k(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + s.append("west", embedding=_emb([-1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_search_sorted_by_similarity(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + s.append("far", embedding=_emb([0, -1, 0])) + s.append("close", embedding=_emb([0.9, 0.1, 0])) + s.append("exact", embedding=_emb([1, 0, 0])) + + results = s.search(_emb([1, 0, 0]), k=3).fetch() + assert results[0].data == "exact" + assert results[1].data == "close" + assert results[2].data == "far" + # Descending similarity + assert results[0].similarity >= results[1].similarity >= results[2].similarity + + def test_search_skips_non_embedded(self, memory_store) -> None: + s = memory_store.stream("mixed", str) + s.append("plain") # no embedding + s.append("embedded", embedding=_emb([1, 0, 0])) + + results = s.search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "embedded" + + def test_search_with_filters(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) + s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) + + # Only the late one should pass the after filter + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "late" + + def test_search_with_limit(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + for i in range(10): + s.append(f"item{i}", embedding=_emb([1, 0, 0])) + + # search k=5 then limit 2 + results = s.search(_emb([1, 0, 0]), k=5).limit(2).fetch() + assert len(results) == 2 + + def test_search_with_live_raises(self, memory_store) -> None: + s = memory_store.stream("vecs", str) + s.append("x", embedding=_emb([1, 0, 0])) + with pytest.raises(TypeError, match="Cannot combine"): + list(s.live().search(_emb([1, 0, 0]), k=5)) + + +class TestTextSearch: + def test_search_text_substring(self, memory_store) -> None: + s = memory_store.stream("logs", str) + s.append("motor fault detected") + s.append("temperature normal") + s.append("motor overheating") + + results = s.search_text("motor").fetch() + assert len(results) == 2 + assert {r.data for r in results} == {"motor fault detected", "motor overheating"} + + def test_search_text_case_insensitive(self, memory_store) -> None: + s = memory_store.stream("logs", str) + s.append("Motor Fault") + s.append("other event") + + results = s.search_text("motor fault").fetch() + assert len(results) == 1 + + def test_search_text_with_filters(self, memory_store) -> None: + s = memory_store.stream("logs", str) + s.append("motor fault", ts=10.0) + s.append("motor warning", ts=20.0) + s.append("motor fault", ts=30.0) + + results = s.after(15.0).search_text("fault").fetch() + assert len(results) == 1 + assert results[0].ts == 30.0 + + def test_search_text_no_match(self, memory_store) -> None: + s = memory_store.stream("logs", str) + s.append("all clear") + + results = s.search_text("motor").fetch() + assert len(results) == 0 + + +class TestSaveEmbeddings: + def test_save_preserves_embeddings(self, memory_store) -> None: + src = memory_store.stream("source", str) + dst = memory_store.stream("dest", str) + + emb = _emb([1, 0, 0]) + src.append("item", embedding=emb) + src.save(dst) + + results = dst.fetch() + assert len(results) == 1 + assert isinstance(results[0], EmbeddedObservation) + # Same vector content (different Embedding instance after re-append) + np.testing.assert_array_almost_equal(results[0].embedding.to_numpy(), emb.to_numpy()) + + def test_save_mixed_embedded_and_plain(self, memory_store) -> None: + src = memory_store.stream("source", str) + dst = memory_store.stream("dest", str) + + src.append("plain") + src.append("embedded", embedding=_emb([0, 1, 0])) + src.save(dst) + + results = dst.fetch() + assert len(results) == 2 + assert type(results[0]) is Observation + assert isinstance(results[1], EmbeddedObservation) + + +class _MockEmbeddingModel: + """Fake EmbeddingModel that returns deterministic unit vectors.""" + + device = "cpu" + + def embed(self, *images): + vecs = [] + for img in images: + rng = np.random.default_rng(hash(str(img)) % 2**32) + v = rng.standard_normal(8).astype(np.float32) + v /= np.linalg.norm(v) + vecs.append(Embedding(vector=v)) + return vecs if len(vecs) > 1 else vecs[0] + + def embed_text(self, *texts): + vecs = [] + for text in texts: + rng = np.random.default_rng(hash(text) % 2**32) + v = rng.standard_normal(8).astype(np.float32) + v /= np.linalg.norm(v) + vecs.append(Embedding(vector=v)) + return vecs if len(vecs) > 1 else vecs[0] + + +class TestEmbedTransformers: + def test_embed_images_produces_embedded_observations(self, memory_store) -> None: + from dimos.memory2.embed import EmbedImages + + model = _MockEmbeddingModel() + s = memory_store.stream("imgs", str) + s.append("img1", ts=1.0) + s.append("img2", ts=2.0) + + results = s.transform(EmbedImages(model)).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddedObservation) + assert isinstance(obs.embedding, Embedding) + assert obs.embedding.to_numpy().shape == (8,) + + def test_embed_text_produces_embedded_observations(self, memory_store) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + s = memory_store.stream("logs", str) + s.append("motor fault", ts=1.0) + s.append("all clear", ts=2.0) + + results = s.transform(EmbedText(model)).fetch() + assert len(results) == 2 + for obs in results: + assert isinstance(obs, EmbeddedObservation) + assert isinstance(obs.embedding, Embedding) + + def test_embed_preserves_data(self, memory_store) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + s = memory_store.stream("logs", str) + s.append("hello", ts=1.0) + + result = s.transform(EmbedText(model)).first() + assert result.data == "hello" + + def test_embed_then_search(self, memory_store) -> None: + from dimos.memory2.embed import EmbedText + + model = _MockEmbeddingModel() + s = memory_store.stream("logs", str) + for i in range(10): + s.append(f"log entry {i}", ts=float(i)) + + embedded = s.transform(EmbedText(model)) + # Get the embedding for the first item, then search for similar + first_emb = embedded.first().embedding + results = embedded.search(first_emb, k=3).fetch() + assert len(results) == 3 + # First result should be the exact match + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_embed_batching(self, memory_store) -> None: + from dimos.memory2.embed import EmbedText + + call_sizes: list[int] = [] + + class _TrackingModel(_MockEmbeddingModel): + def embed_text(self, *texts): + call_sizes.append(len(texts)) + return super().embed_text(*texts) + + model = _TrackingModel() + s = memory_store.stream("logs", str) + for i in range(5): + s.append(f"entry {i}") + + list(s.transform(EmbedText(model, batch_size=2))) + # 5 items with batch_size=2 → 3 calls (2, 2, 1) + assert call_sizes == [2, 2, 1] + + +class TestPluggableVectorStore: + """Verify that injecting a VectorStore via store config actually delegates search.""" + + def test_append_stores_in_vector_store(self) -> None: + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs = MemoryVectorStore() + with MemoryStore(vector_store=vs) as store: + s = store.stream("vecs", str) + s.append("hello", embedding=_emb([1, 0, 0])) + s.append("world", embedding=_emb([0, 1, 0])) + + assert len(vs._vectors["vecs"]) == 2 + + def test_append_without_embedding_skips_vector_store(self) -> None: + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs = MemoryVectorStore() + with MemoryStore(vector_store=vs) as store: + s = store.stream("plain", str) + s.append("no embedding") + + assert "plain" not in vs._vectors + + def test_search_uses_vector_store(self) -> None: + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs = MemoryVectorStore() + with MemoryStore(vector_store=vs) as store: + s = store.stream("vecs", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + s.append("west", embedding=_emb([-1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity is not None + assert results[0].similarity > 0.99 + + def test_search_with_filters_via_vector_store(self) -> None: + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs = MemoryVectorStore() + with MemoryStore(vector_store=vs) as store: + s = store.stream("vecs", str) + s.append("early", ts=10.0, embedding=_emb([1, 0, 0])) + s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) + + # Filter + search: only "late" passes the after filter + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + assert len(results) == 1 + assert results[0].data == "late" + + def test_per_stream_vector_store_override(self) -> None: + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.vectorstore.memory import MemoryVectorStore + + vs_default = MemoryVectorStore() + vs_override = MemoryVectorStore() + with MemoryStore(vector_store=vs_default) as store: + # Stream with default vector store + s1 = store.stream("s1", str) + s1.append("a", embedding=_emb([1, 0, 0])) + + # Stream with overridden vector store + s2 = store.stream("s2", str, vector_store=vs_override) + s2.append("b", embedding=_emb([0, 1, 0])) + + assert "s1" in vs_default._vectors + assert "s1" not in vs_override._vectors + assert "s2" in vs_override._vectors + assert "s2" not in vs_default._vectors diff --git a/dimos/memory2/test_registry.py b/dimos/memory2/test_registry.py new file mode 100644 index 0000000000..d611073075 --- /dev/null +++ b/dimos/memory2/test_registry.py @@ -0,0 +1,263 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RegistryStore and serialization round-trips.""" + +from __future__ import annotations + +import pytest + +from dimos.memory2.blobstore.file import FileBlobStore +from dimos.memory2.blobstore.sqlite import SqliteBlobStore, SqliteBlobStoreConfig +from dimos.memory2.notifier.subject import SubjectNotifier +from dimos.memory2.observationstore.sqlite import SqliteObservationStoreConfig +from dimos.memory2.registry import RegistryStore, deserialize_component, qual +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.vectorstore.sqlite import SqliteVectorStore, SqliteVectorStoreConfig + + +class TestQual: + def test_qual_blob_store(self) -> None: + assert qual(SqliteBlobStore) == "dimos.memory2.blobstore.sqlite.SqliteBlobStore" + + def test_qual_file_blob_store(self) -> None: + assert qual(FileBlobStore) == "dimos.memory2.blobstore.file.FileBlobStore" + + def test_qual_vector_store(self) -> None: + assert qual(SqliteVectorStore) == "dimos.memory2.vectorstore.sqlite.SqliteVectorStore" + + def test_qual_notifier(self) -> None: + assert qual(SubjectNotifier) == "dimos.memory2.notifier.subject.SubjectNotifier" + + +class TestRegistryStore: + def test_put_get_round_trip(self, tmp_path) -> None: + from dimos.memory2.utils.sqlite import open_sqlite_connection + + conn = open_sqlite_connection(str(tmp_path / "reg.db")) + reg = RegistryStore(conn=conn) + + config = {"payload_module": "builtins.str", "codec_id": "pickle"} + reg.put("my_stream", config) + result = reg.get("my_stream") + assert result == config + conn.close() + + def test_get_missing(self, tmp_path) -> None: + from dimos.memory2.utils.sqlite import open_sqlite_connection + + conn = open_sqlite_connection(str(tmp_path / "reg.db")) + reg = RegistryStore(conn=conn) + assert reg.get("nonexistent") is None + conn.close() + + def test_list_streams(self, tmp_path) -> None: + from dimos.memory2.utils.sqlite import open_sqlite_connection + + conn = open_sqlite_connection(str(tmp_path / "reg.db")) + reg = RegistryStore(conn=conn) + reg.put("a", {"x": 1}) + reg.put("b", {"x": 2}) + assert sorted(reg.list_streams()) == ["a", "b"] + conn.close() + + def test_delete(self, tmp_path) -> None: + from dimos.memory2.utils.sqlite import open_sqlite_connection + + conn = open_sqlite_connection(str(tmp_path / "reg.db")) + reg = RegistryStore(conn=conn) + reg.put("x", {"y": 1}) + reg.delete("x") + assert reg.get("x") is None + conn.close() + + def test_upsert(self, tmp_path) -> None: + from dimos.memory2.utils.sqlite import open_sqlite_connection + + conn = open_sqlite_connection(str(tmp_path / "reg.db")) + reg = RegistryStore(conn=conn) + reg.put("x", {"v": 1}) + reg.put("x", {"v": 2}) + assert reg.get("x") == {"v": 2} + conn.close() + + +class TestComponentSerialization: + def test_sqlite_observation_store_config(self) -> None: + cfg = SqliteObservationStoreConfig(page_size=512, path="test.db") + dumped = cfg.model_dump() + restored = SqliteObservationStoreConfig(**dumped) + assert restored.page_size == 512 + + def test_sqlite_blob_store_config(self) -> None: + cfg = SqliteBlobStoreConfig(path="/tmp/test.db") + dumped = cfg.model_dump() + restored = SqliteBlobStoreConfig(**dumped) + assert restored.path == "/tmp/test.db" + + def test_sqlite_blob_store_roundtrip(self, tmp_path) -> None: + store = SqliteBlobStore(path=str(tmp_path / "blob.db")) + data = store.serialize() + assert data["class"] == qual(SqliteBlobStore) + restored = deserialize_component(data) + assert isinstance(restored, SqliteBlobStore) + + def test_file_blob_store_roundtrip(self, tmp_path) -> None: + store = FileBlobStore(root=str(tmp_path / "blobs")) + data = store.serialize() + assert data["class"] == qual(FileBlobStore) + restored = deserialize_component(data) + assert isinstance(restored, FileBlobStore) + assert str(restored._root) == str(tmp_path / "blobs") + + def test_sqlite_vector_store_config(self) -> None: + cfg = SqliteVectorStoreConfig(path="/tmp/vec.db") + dumped = cfg.model_dump() + restored = SqliteVectorStoreConfig(**dumped) + assert restored.path == "/tmp/vec.db" + + def test_sqlite_vector_store_roundtrip(self, tmp_path) -> None: + store = SqliteVectorStore(path=str(tmp_path / "vec.db")) + data = store.serialize() + assert data["class"] == qual(SqliteVectorStore) + restored = deserialize_component(data) + assert isinstance(restored, SqliteVectorStore) + + def test_subject_notifier_roundtrip(self) -> None: + notifier = SubjectNotifier() + data = notifier.serialize() + assert data["class"] == qual(SubjectNotifier) + restored = deserialize_component(data) + assert isinstance(restored, SubjectNotifier) + + def test_deserialize_component(self, tmp_path) -> None: + store = FileBlobStore(root=str(tmp_path / "blobs")) + data = store.serialize() + restored = deserialize_component(data) + assert isinstance(restored, FileBlobStore) + + +class TestBackendSerialization: + def test_backend_serialize(self, tmp_path) -> None: + from dimos.memory2.backend import Backend + from dimos.memory2.codecs.pickle import PickleCodec + from dimos.memory2.observationstore.memory import ListObservationStore + + backend = Backend( + metadata_store=ListObservationStore(name="test"), + codec=PickleCodec(), + blob_store=FileBlobStore(root=str(tmp_path / "blobs")), + notifier=SubjectNotifier(), + ) + data = backend.serialize() + assert data["codec_id"] == "pickle" + assert data["blob_store"]["class"] == qual(FileBlobStore) + assert data["notifier"]["class"] == qual(SubjectNotifier) + + +class TestStoreReopen: + def test_reopen_preserves_data(self, tmp_path) -> None: + """Create a store, write data, close, reopen, read back.""" + db = str(tmp_path / "test.db") + with SqliteStore(path=db) as store: + s = store.stream("nums", int) + s.append(42, ts=1.0) + s.append(99, ts=2.0) + + with SqliteStore(path=db) as store2: + s2 = store2.stream("nums", int) + assert s2.count() == 2 + obs = s2.fetch() + assert [o.data for o in obs] == [42, 99] + + def test_reopen_preserves_codec(self, tmp_path) -> None: + """Codec ID is stored and restored on reopen.""" + db = str(tmp_path / "codec.db") + with SqliteStore(path=db) as store: + s = store.stream("data", str, codec="pickle") + s.append("hello", ts=1.0) + + with SqliteStore(path=db) as store2: + s2 = store2.stream("data", str) + assert s2.first().data == "hello" + + def test_reopen_preserves_eager_blobs(self, tmp_path) -> None: + """eager_blobs override is stored in registry and restored on reopen.""" + db = str(tmp_path / "eager.db") + with SqliteStore(path=db) as store: + s = store.stream("data", str, eager_blobs=True) + s.append("test", ts=1.0) + + with SqliteStore(path=db) as store2: + stored = store2._registry.get("data") + assert stored is not None + assert stored["eager_blobs"] is True + + def test_reopen_preserves_file_blob_store(self, tmp_path) -> None: + """FileBlobStore override is stored and restored on reopen.""" + db = str(tmp_path / "file_blob.db") + blob_dir = str(tmp_path / "blobs") + with SqliteStore(path=db) as store: + fbs = FileBlobStore(root=blob_dir) + fbs.start() + s = store.stream("imgs", str, blob_store=fbs) + s.append("image_data", ts=1.0) + + with SqliteStore(path=db) as store2: + stored = store2._registry.get("imgs") + assert stored is not None + assert stored["blob_store"]["class"] == qual(FileBlobStore) + assert stored["blob_store"]["config"]["root"] == blob_dir + + def test_reopen_type_mismatch_raises(self, tmp_path) -> None: + """Opening a stream with a different payload type raises ValueError.""" + db = str(tmp_path / "mismatch.db") + with SqliteStore(path=db) as store: + store.stream("nums", int) + + with SqliteStore(path=db) as store2: + with pytest.raises(ValueError, match="was created with type"): + store2.stream("nums", str) + + def test_reopen_list_streams(self, tmp_path) -> None: + """list_streams includes streams from registry on reopen.""" + db = str(tmp_path / "list.db") + with SqliteStore(path=db) as store: + store.stream("a", int) + store.stream("b", str) + + with SqliteStore(path=db) as store2: + assert sorted(store2.list_streams()) == ["a", "b"] + + def test_reopen_without_payload_type(self, tmp_path) -> None: + """Reopening a known stream without payload_type works.""" + db = str(tmp_path / "no_type.db") + with SqliteStore(path=db) as store: + s = store.stream("data", str) + s.append("hello", ts=1.0) + + with SqliteStore(path=db) as store2: + s2 = store2.stream("data") + assert s2.first().data == "hello" + + def test_reopen_preserves_page_size(self, tmp_path) -> None: + """page_size is stored in registry and restored on reopen.""" + db = str(tmp_path / "page.db") + with SqliteStore(path=db, page_size=512) as store: + store.stream("data", str) + + with SqliteStore(path=db) as store2: + stored = store2._registry.get("data") + assert stored is not None + assert stored["page_size"] == 512 diff --git a/dimos/memory2/test_save.py b/dimos/memory2/test_save.py new file mode 100644 index 0000000000..13ee73d46a --- /dev/null +++ b/dimos/memory2/test_save.py @@ -0,0 +1,123 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Stream.save() and Notifier integration.""" + +from __future__ import annotations + +import pytest + +from dimos.memory2.backend import Backend +from dimos.memory2.codecs.pickle import PickleCodec +from dimos.memory2.notifier.base import Notifier +from dimos.memory2.observationstore.memory import ListObservationStore +from dimos.memory2.stream import Stream +from dimos.memory2.transform import FnTransformer +from dimos.memory2.type.observation import Observation + + +def _make_backend(name: str = "test") -> Backend[int]: + return Backend(metadata_store=ListObservationStore[int](name=name), codec=PickleCodec()) + + +def make_stream(n: int = 5, start_ts: float = 0.0) -> Stream[int]: + backend = _make_backend() + for i in range(n): + backend.append(Observation(id=-1, ts=start_ts + i, _data=i * 10)) + return Stream(source=backend) + + +# ═══════════════════════════════════════════════════════════════════ +# Protocol checks +# ═══════════════════════════════════════════════════════════════════ + + +class TestProtocol: + def test_backend_has_notifier(self) -> None: + b = _make_backend("x") + assert isinstance(b.notifier, Notifier) + + +# ═══════════════════════════════════════════════════════════════════ +# .save() +# ═══════════════════════════════════════════════════════════════════ + + +class TestSave: + def test_save_populates_target(self) -> None: + source = make_stream(3) + target = Stream(source=_make_backend("target")) + + source.save(target) + + results = target.fetch() + assert len(results) == 3 + assert [o.data for o in results] == [0, 10, 20] + + def test_save_returns_target_stream(self) -> None: + source = make_stream(2) + target = Stream(source=_make_backend("target")) + + result = source.save(target) + + assert result is target + + def test_save_preserves_data(self) -> None: + backend = _make_backend("src") + backend.append(Observation(id=-1, ts=1.0, pose=(1, 2, 3), tags={"label": "cat"}, _data=42)) + source = Stream(source=backend) + + target = Stream(source=_make_backend("dst")) + source.save(target) + + obs = target.first() + assert obs.data == 42 + assert obs.ts == 1.0 + assert obs.pose == (1, 2, 3) + assert obs.tags == {"label": "cat"} + + def test_save_with_transform(self) -> None: + source = make_stream(3) # data: 0, 10, 20 + doubled = source.transform(FnTransformer(lambda obs: obs.derive(data=obs.data * 2))) + + target = Stream(source=_make_backend("target")) + doubled.save(target) + + assert [o.data for o in target.fetch()] == [0, 20, 40] + + def test_save_rejects_transform_target(self) -> None: + source = make_stream(2) + base = make_stream(2) + transform_stream = base.transform(FnTransformer(lambda obs: obs.derive(obs.data))) + + with pytest.raises(TypeError, match="Cannot save to a transform stream"): + source.save(transform_stream) + + def test_save_target_queryable(self) -> None: + source = make_stream(5, start_ts=0.0) # ts: 0,1,2,3,4 + + target = Stream(source=_make_backend("target")) + result = source.save(target) + + after_2 = result.after(2.0).fetch() + assert [o.data for o in after_2] == [30, 40] + + def test_save_empty_source(self) -> None: + source = make_stream(0) + target = Stream(source=_make_backend("target")) + + result = source.save(target) + + assert result.count() == 0 + assert result.fetch() == [] diff --git a/dimos/memory2/test_store.py b/dimos/memory2/test_store.py new file mode 100644 index 0000000000..dfba6d6d2b --- /dev/null +++ b/dimos/memory2/test_store.py @@ -0,0 +1,527 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grid tests for Store implementations. + +Runs the same test logic against every Store backend (MemoryStore, SqliteStore, ...). +The parametrized ``session`` fixture from conftest runs each test against both backends. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +from dimos.memory2.blobstore.base import BlobStore +from dimos.memory2.vectorstore.base import VectorStore + +if TYPE_CHECKING: + from dimos.memory2.store.base import Store + + +class TestStoreBasic: + """Core store operations that every backend must support.""" + + def test_create_stream_and_append(self, session: Store) -> None: + s = session.stream("images", bytes) + obs = s.append(b"frame1", tags={"camera": "front"}) + + assert obs.data == b"frame1" + assert obs.tags["camera"] == "front" + assert obs.ts > 0 + + def test_append_multiple_and_fetch(self, session: Store) -> None: + s = session.stream("sensor", float) + s.append(1.0, ts=100.0) + s.append(2.0, ts=200.0) + s.append(3.0, ts=300.0) + + results = s.fetch() + assert len(results) == 3 + assert [o.data for o in results] == [1.0, 2.0, 3.0] + + def test_iterate_stream(self, session: Store) -> None: + s = session.stream("log", str) + s.append("a", ts=1.0) + s.append("b", ts=2.0) + + collected = [obs.data for obs in s] + assert collected == ["a", "b"] + + def test_count(self, session: Store) -> None: + s = session.stream("events", str) + assert s.count() == 0 + s.append("x") + s.append("y") + assert s.count() == 2 + + def test_first_and_last(self, session: Store) -> None: + s = session.stream("data", int) + s.append(10, ts=1.0) + s.append(20, ts=2.0) + s.append(30, ts=3.0) + + assert s.first().data == 10 + assert s.last().data == 30 + + def test_first_empty_raises(self, session: Store) -> None: + s = session.stream("empty", int) + with pytest.raises(LookupError): + s.first() + + def test_exists(self, session: Store) -> None: + s = session.stream("check", str) + assert not s.exists() + s.append("hi") + assert s.exists() + + def test_filter_after(self, session: Store) -> None: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.after(15.0).fetch() + assert [o.data for o in results] == [2, 3] + + def test_filter_before(self, session: Store) -> None: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.before(25.0).fetch() + assert [o.data for o in results] == [1, 2] + + def test_filter_time_range(self, session: Store) -> None: + s = session.stream("ts_data", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.time_range(15.0, 25.0).fetch() + assert [o.data for o in results] == [2] + + def test_filter_tags(self, session: Store) -> None: + s = session.stream("tagged", str) + s.append("a", tags={"kind": "info"}) + s.append("b", tags={"kind": "error"}) + s.append("c", tags={"kind": "info"}) + + results = s.tags(kind="info").fetch() + assert [o.data for o in results] == ["a", "c"] + + def test_limit_and_offset(self, session: Store) -> None: + s = session.stream("paged", int) + for i in range(5): + s.append(i, ts=float(i)) + + page = s.offset(1).limit(2).fetch() + assert [o.data for o in page] == [1, 2] + + def test_order_by_desc(self, session: Store) -> None: + s = session.stream("ordered", int) + s.append(1, ts=10.0) + s.append(2, ts=20.0) + s.append(3, ts=30.0) + + results = s.order_by("ts", desc=True).fetch() + assert [o.data for o in results] == [3, 2, 1] + + def test_separate_streams_isolated(self, session: Store) -> None: + a = session.stream("stream_a", str) + b = session.stream("stream_b", str) + + a.append("in_a") + b.append("in_b") + + assert [o.data for o in a] == ["in_a"] + assert [o.data for o in b] == ["in_b"] + + def test_same_stream_on_repeated_calls(self, session: Store) -> None: + s1 = session.stream("reuse", str) + s2 = session.stream("reuse", str) + assert s1 is s2 + + def test_append_with_embedding(self, session: Store) -> None: + import numpy as np + + from dimos.memory2.type.observation import EmbeddedObservation + from dimos.models.embedding.base import Embedding + + s = session.stream("vectors", str) + emb = Embedding(vector=np.array([1.0, 0.0, 0.0], dtype=np.float32)) + obs = s.append("hello", embedding=emb) + assert isinstance(obs, EmbeddedObservation) + assert obs.embedding is emb + + def test_search_top_k(self, session: Store) -> None: + import numpy as np + + from dimos.models.embedding.base import Embedding + + def _emb(v: list[float]) -> Embedding: + a = np.array(v, dtype=np.float32) + return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) + + s = session.stream("searchable", str) + s.append("north", embedding=_emb([0, 1, 0])) + s.append("east", embedding=_emb([1, 0, 0])) + s.append("south", embedding=_emb([0, -1, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(results) == 2 + assert results[0].data == "north" + assert results[0].similarity > 0.99 + + def test_search_text(self, session: Store) -> None: + s = session.stream("logs", str) + s.append("motor fault") + s.append("temperature ok") + + # SqliteObservationStore blocks search_text to prevent full table scans + try: + results = s.search_text("motor").fetch() + except NotImplementedError: + pytest.skip("search_text not supported on this backend") + assert len(results) == 1 + assert results[0].data == "motor fault" + + +class TestBlobLoading: + """Verify lazy and eager blob loading paths.""" + + def test_sqlite_lazy_by_default(self, sqlite_store: Store) -> None: + """Default sqlite iteration uses lazy loaders — data is _UNLOADED until accessed.""" + from dimos.memory2.type.observation import _Unloaded + + s = sqlite_store.stream("lazy_test", str) + s.append("hello", ts=1.0) + s.append("world", ts=2.0) + + for obs in s: + # Before accessing .data, _data should be the unloaded sentinel + assert isinstance(obs._data, _Unloaded) + assert obs._loader is not None + # Accessing .data triggers the loader + val = obs.data + assert isinstance(val, str) + # After loading, _loader is cleared + assert obs._loader is None + + def test_sqlite_eager_loads_inline(self, sqlite_store: Store) -> None: + """With eager_blobs=True, data is loaded via JOIN — no lazy loader.""" + from dimos.memory2.type.observation import _Unloaded + + s = sqlite_store.stream("eager_test", str, eager_blobs=True) + s.append("hello", ts=1.0) + s.append("world", ts=2.0) + + for obs in s: + # Data should already be loaded — no lazy sentinel + assert not isinstance(obs._data, _Unloaded) + assert obs._loader is None + assert isinstance(obs.data, str) + + def test_sqlite_lazy_and_eager_same_values(self, sqlite_store: Store) -> None: + """Both paths must return identical data.""" + lazy_s = sqlite_store.stream("vals", str) + lazy_s.append("alpha", ts=1.0, tags={"k": "v"}) + lazy_s.append("beta", ts=2.0, tags={"k": "w"}) + + # Lazy read + lazy_results = lazy_s.fetch() + + # Eager read — new stream handle with eager_blobs on same backend + eager_s = sqlite_store.stream("vals", str, eager_blobs=True) + eager_results = eager_s.fetch() + + assert [o.data for o in lazy_results] == [o.data for o in eager_results] + assert [o.tags for o in lazy_results] == [o.tags for o in eager_results] + assert [o.ts for o in lazy_results] == [o.ts for o in eager_results] + + def test_memory_lazy_with_blobstore(self, tmp_path) -> None: + """MemoryStore with a BlobStore uses lazy loaders.""" + from dimos.memory2.blobstore.file import FileBlobStore + from dimos.memory2.store.memory import MemoryStore + from dimos.memory2.type.observation import _Unloaded + + bs = FileBlobStore(root=str(tmp_path / "blobs")) + bs.start() + with MemoryStore(blob_store=bs) as store: + s = store.stream("mem_lazy", str) + s.append("data1", ts=1.0) + + obs = s.first() + # Backend replaces _data with _UNLOADED when blob_store is set + assert isinstance(obs._data, _Unloaded) + assert obs.data == "data1" + bs.stop() + + +class SpyBlobStore(BlobStore): + """BlobStore that records all calls for verification.""" + + def __init__(self) -> None: + super().__init__() + self.puts: list[tuple[str, int, bytes]] = [] + self.gets: list[tuple[str, int]] = [] + self.store: dict[tuple[str, int], bytes] = {} + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def put(self, stream: str, key: int, data: bytes) -> None: + self.puts.append((stream, key, data)) + self.store[(stream, key)] = data + + def get(self, stream: str, key: int) -> bytes: + self.gets.append((stream, key)) + return self.store[(stream, key)] + + def delete(self, stream: str, key: int) -> None: + self.store.pop((stream, key), None) + + +class SpyVectorStore(VectorStore): + """VectorStore that records all calls for verification.""" + + def __init__(self) -> None: + super().__init__() + self.puts: list[tuple[str, int]] = [] + self.searches: list[tuple[str, int]] = [] + self.vectors: dict[str, dict[int, Any]] = {} + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def put(self, stream: str, key: int, embedding: Any) -> None: + self.puts.append((stream, key)) + self.vectors.setdefault(stream, {})[key] = embedding + + def search(self, stream: str, query: Any, k: int) -> list[tuple[int, float]]: + self.searches.append((stream, k)) + vectors = self.vectors.get(stream, {}) + if not vectors: + return [] + scored = [(key, float(emb @ query)) for key, emb in vectors.items()] + scored.sort(key=lambda x: x[1], reverse=True) + return scored[:k] + + def delete(self, stream: str, key: int) -> None: + self.vectors.get(stream, {}).pop(key, None) + + +@pytest.fixture +def memory_spy_session(): + from dimos.memory2.store.memory import MemoryStore + + blob_spy = SpyBlobStore() + vec_spy = SpyVectorStore() + with MemoryStore(blob_store=blob_spy, vector_store=vec_spy) as store: + yield store, blob_spy, vec_spy + + +@pytest.fixture +def sqlite_spy_session(tmp_path): + from dimos.memory2.store.sqlite import SqliteStore + + blob_spy = SpyBlobStore() + vec_spy = SpyVectorStore() + with SqliteStore( + path=str(tmp_path / "spy.db"), blob_store=blob_spy, vector_store=vec_spy + ) as store: + yield store, blob_spy, vec_spy + + +@pytest.fixture(params=["memory_spy_session", "sqlite_spy_session"]) +def spy_session(request: pytest.FixtureRequest): + return request.getfixturevalue(request.param) + + +class TestStoreDelegation: + """Verify all backends delegate to pluggable BlobStore and VectorStore.""" + + def test_append_calls_blob_put(self, spy_session) -> None: + store, blob_spy, _vec_spy = spy_session + s = store.stream("blobs", str) + s.append("first", ts=1.0) + s.append("second", ts=2.0) + + assert len(blob_spy.puts) == 2 + assert all(stream == "blobs" for stream, _k, _d in blob_spy.puts) + + def test_iterate_calls_blob_get(self, spy_session) -> None: + store, blob_spy, _vec_spy = spy_session + s = store.stream("blobs", str) + s.append("a", ts=1.0) + s.append("b", ts=2.0) + + blob_spy.gets.clear() + for obs in s: + _ = obs.data + assert len(blob_spy.gets) == 2 + + def test_append_embedding_calls_vector_put(self, spy_session) -> None: + import numpy as np + + from dimos.models.embedding.base import Embedding + + def _emb(v: list[float]) -> Embedding: + a = np.array(v, dtype=np.float32) + return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) + + store, _blob_spy, vec_spy = spy_session + s = store.stream("vecs", str) + s.append("a", ts=1.0, embedding=_emb([1, 0, 0])) + s.append("b", ts=2.0, embedding=_emb([0, 1, 0])) + s.append("c", ts=3.0) # no embedding + + assert len(vec_spy.puts) == 2 + + def test_search_calls_vector_search(self, spy_session) -> None: + import numpy as np + + from dimos.models.embedding.base import Embedding + + def _emb(v: list[float]) -> Embedding: + a = np.array(v, dtype=np.float32) + return Embedding(vector=a / (np.linalg.norm(a) + 1e-10)) + + store, _blob_spy, vec_spy = spy_session + s = store.stream("vecs", str) + s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) + s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) + + results = s.search(_emb([0, 1, 0]), k=2).fetch() + assert len(vec_spy.searches) == 1 + assert results[0].data == "north" + + +class TestStandaloneComponents: + """Verify each SQLite component works standalone with path= (no Store needed).""" + + def test_observation_store_standalone(self, tmp_path) -> None: + from dimos.memory2.codecs.base import codec_for + from dimos.memory2.observationstore.sqlite import SqliteObservationStore + from dimos.memory2.type.filter import StreamQuery + from dimos.memory2.type.observation import Observation + + db = str(tmp_path / "obs.db") + codec = codec_for(str) + with SqliteObservationStore(path=db, name="events", codec=codec) as store: + obs = Observation(id=0, ts=1.0, _data="hello") + row_id = store.insert(obs) + store.commit() + assert row_id == 1 + + results = list(store.query(StreamQuery())) + assert len(results) == 1 + assert results[0].ts == 1.0 + + def test_blob_store_standalone(self, tmp_path) -> None: + from dimos.memory2.blobstore.sqlite import SqliteBlobStore + + db = str(tmp_path / "blob.db") + with SqliteBlobStore(path=db) as store: + store.put("stream1", 1, b"data1") + store.put("stream1", 2, b"data2") + assert store.get("stream1", 1) == b"data1" + assert store.get("stream1", 2) == b"data2" + + def test_vector_store_standalone(self, tmp_path) -> None: + import numpy as np + + from dimos.memory2.vectorstore.sqlite import SqliteVectorStore + from dimos.models.embedding.base import Embedding + + db = str(tmp_path / "vec.db") + with SqliteVectorStore(path=db) as store: + emb1 = Embedding(vector=np.array([1, 0, 0], dtype=np.float32)) + emb2 = Embedding(vector=np.array([0, 1, 0], dtype=np.float32)) + store.put("vecs", 1, emb1) + store.put("vecs", 2, emb2) + + results = store.search("vecs", emb1, k=2) + assert len(results) == 2 + assert results[0][0] == 1 # closest to emb1 is itself + + def test_conn_and_path_mutually_exclusive(self, tmp_path) -> None: + import sqlite3 + + from dimos.memory2.blobstore.sqlite import SqliteBlobStore + from dimos.memory2.observationstore.sqlite import SqliteObservationStore + from dimos.memory2.vectorstore.sqlite import SqliteVectorStore + + conn = sqlite3.connect(":memory:") + db = str(tmp_path / "test.db") + + with pytest.raises(ValueError, match="either conn or path"): + SqliteBlobStore(conn=conn, path=db) + with pytest.raises(ValueError, match="either conn or path"): + SqliteVectorStore(conn=conn, path=db) + with pytest.raises(ValueError, match="either conn or path"): + SqliteObservationStore(conn=conn, name="x", path=db) + with pytest.raises(ValueError, match="either conn or path"): + SqliteBlobStore() + with pytest.raises(ValueError, match="either conn or path"): + SqliteVectorStore() + with pytest.raises(ValueError, match="either conn or path"): + SqliteObservationStore(name="x") + conn.close() + + +class TestStreamAccessor: + """Test attribute-style stream access via store.streams.""" + + def test_accessor_returns_same_stream(self, session: Store) -> None: + s = session.stream("images", bytes) + assert session.streams.images is s + + def test_accessor_dir_lists_streams(self, session: Store) -> None: + session.stream("alpha", str) + session.stream("beta", int) + names = dir(session.streams) + assert "alpha" in names + assert "beta" in names + + def test_accessor_missing_raises(self, session: Store) -> None: + with pytest.raises(AttributeError, match="nonexistent"): + _ = session.streams.nonexistent + + def test_accessor_getitem(self, session: Store) -> None: + s = session.stream("data", float) + assert session.streams["data"] is s + + def test_accessor_getitem_missing_raises(self, session: Store) -> None: + with pytest.raises(KeyError): + session.streams["nope"] + + def test_accessor_repr(self, session: Store) -> None: + session.stream("x", str) + r = repr(session.streams) + assert "x" in r + assert "StreamAccessor" in r + + def test_accessor_dynamic(self, session: Store) -> None: + assert "late" not in dir(session.streams) + session.stream("late", str) + assert "late" in dir(session.streams) diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py new file mode 100644 index 0000000000..03c3caec76 --- /dev/null +++ b/dimos/memory2/test_stream.py @@ -0,0 +1,728 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""memory stream tests — serves as living documentation of the lazy stream API. + +Each test demonstrates a specific capability with clear setup, action, and assertion. +""" + +from __future__ import annotations + +import threading +import time +from typing import TYPE_CHECKING + +import pytest + +from dimos.memory2.buffer import KeepLast, Unbounded +from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer +from dimos.memory2.type.observation import Observation + +if TYPE_CHECKING: + from collections.abc import Callable + + from dimos.memory2.stream import Stream + + +@pytest.fixture +def make_stream(session) -> Callable[..., Stream[int]]: + stream_index = 0 + + def f(n: int = 5, start_ts: float = 0.0): + nonlocal stream_index + stream_index += 1 + stream = session.stream(f"test{stream_index}", int) + for i in range(n): + stream.append(i * 10, ts=start_ts + i) + return stream + + return f + + +# ═══════════════════════════════════════════════════════════════════ +# 1. Basic iteration +# ═══════════════════════════════════════════════════════════════════ + + +class TestBasicIteration: + """Streams are lazy iterables — nothing runs until you iterate.""" + + def test_iterate_yields_all_observations(self, make_stream): + stream = make_stream(5) + obs = list(stream) + assert len(obs) == 5 + assert [o.data for o in obs] == [0, 10, 20, 30, 40] + + def test_iterate_preserves_timestamps(self, make_stream): + stream = make_stream(3, start_ts=100.0) + assert [o.ts for o in stream] == [100.0, 101.0, 102.0] + + def test_empty_stream(self, make_stream): + stream = make_stream(0) + assert list(stream) == [] + + def test_fetch_materializes_to_list(self, make_stream): + result = make_stream(3).fetch() + assert isinstance(result, list) + assert len(result) == 3 + + def test_stream_is_reiterable(self, make_stream): + """Same stream can be iterated multiple times — each time re-queries.""" + stream = make_stream(3) + first = [o.data for o in stream] + second = [o.data for o in stream] + assert first == second == [0, 10, 20] + + +# ═══════════════════════════════════════════════════════════════════ +# 2. Temporal filters +# ═══════════════════════════════════════════════════════════════════ + + +class TestTemporalFilters: + """Temporal filters constrain observations by timestamp.""" + + def test_after(self, make_stream): + """.after(t) keeps observations with ts > t.""" + result = make_stream(5).after(2.0).fetch() + assert [o.ts for o in result] == [3.0, 4.0] + + def test_before(self, make_stream): + """.before(t) keeps observations with ts < t.""" + result = make_stream(5).before(2.0).fetch() + assert [o.ts for o in result] == [0.0, 1.0] + + def test_time_range(self, make_stream): + """.time_range(t1, t2) keeps t1 <= ts <= t2.""" + result = make_stream(5).time_range(1.0, 3.0).fetch() + assert [o.ts for o in result] == [1.0, 2.0, 3.0] + + def test_at_with_tolerance(self, make_stream): + """.at(t, tolerance) keeps observations within tolerance of t.""" + result = make_stream(5).at(2.0, tolerance=0.5).fetch() + assert [o.ts for o in result] == [2.0] + + def test_chained_temporal_filters(self, make_stream): + """Filters compose — each narrows the result.""" + result = make_stream(10).after(2.0).before(7.0).fetch() + assert [o.ts for o in result] == [3.0, 4.0, 5.0, 6.0] + + +# ═══════════════════════════════════════════════════════════════════ +# 3. Spatial filter +# ═══════════════════════════════════════════════════════════════════ + + +class TestSpatialFilter: + """.near(pose, radius) filters by Euclidean distance.""" + + def test_near_with_tuples(self, memory_session): + stream = memory_session.stream("spatial") + stream.append("origin", ts=0.0, pose=(0, 0, 0)) + stream.append("close", ts=1.0, pose=(1, 1, 0)) + stream.append("far", ts=2.0, pose=(10, 10, 10)) + + result = stream.near((0, 0, 0), radius=2.0).fetch() + assert [o.data for o in result] == ["origin", "close"] + + def test_near_excludes_no_pose(self, memory_session): + stream = memory_session.stream("spatial") + stream.append("no_pose", ts=0.0) + stream.append("has_pose", ts=1.0, pose=(0, 0, 0)) + + result = stream.near((0, 0, 0), radius=10.0).fetch() + assert [o.data for o in result] == ["has_pose"] + + +# ═══════════════════════════════════════════════════════════════════ +# 4. Tags filter +# ═══════════════════════════════════════════════════════════════════ + + +class TestTagsFilter: + """.filter_tags() matches on observation metadata.""" + + def test_filter_by_tag(self, memory_session): + stream = memory_session.stream("tagged") + stream.append("cat", ts=0.0, tags={"type": "animal", "legs": 4}) + stream.append("car", ts=1.0, tags={"type": "vehicle", "wheels": 4}) + stream.append("dog", ts=2.0, tags={"type": "animal", "legs": 4}) + + result = stream.tags(type="animal").fetch() + assert [o.data for o in result] == ["cat", "dog"] + + def test_filter_multiple_tags(self, memory_session): + stream = memory_session.stream("tagged") + stream.append("a", ts=0.0, tags={"x": 1, "y": 2}) + stream.append("b", ts=1.0, tags={"x": 1, "y": 3}) + + result = stream.tags(x=1, y=2).fetch() + assert [o.data for o in result] == ["a"] + + +# ═══════════════════════════════════════════════════════════════════ +# 5. Ordering, limit, offset +# ═══════════════════════════════════════════════════════════════════ + + +class TestOrderLimitOffset: + def test_limit(self, make_stream): + result = make_stream(10).limit(3).fetch() + assert len(result) == 3 + + def test_offset(self, make_stream): + result = make_stream(5).offset(2).fetch() + assert [o.data for o in result] == [20, 30, 40] + + def test_limit_and_offset(self, make_stream): + result = make_stream(10).offset(2).limit(3).fetch() + assert [o.data for o in result] == [20, 30, 40] + + def test_order_by_ts_desc(self, make_stream): + result = make_stream(5).order_by("ts", desc=True).fetch() + assert [o.ts for o in result] == [4.0, 3.0, 2.0, 1.0, 0.0] + + def test_first(self, make_stream): + obs = make_stream(5).first() + assert obs.data == 0 + + def test_last(self, make_stream): + obs = make_stream(5).last() + assert obs.data == 40 + + def test_first_empty_raises(self, make_stream): + with pytest.raises(LookupError): + make_stream(0).first() + + def test_count(self, make_stream): + assert make_stream(5).count() == 5 + assert make_stream(5).after(2.0).count() == 2 + + def test_exists(self, make_stream): + assert make_stream(5).exists() + assert not make_stream(0).exists() + assert not make_stream(5).after(100.0).exists() + + def test_drain(self, make_stream): + assert make_stream(5).drain() == 5 + assert make_stream(5).after(2.0).drain() == 2 + assert make_stream(0).drain() == 0 + + +# ═══════════════════════════════════════════════════════════════════ +# 6. Functional API: .filter(), .map() +# ═══════════════════════════════════════════════════════════════════ + + +class TestFunctionalAPI: + """Functional combinators receive the full Observation.""" + + def test_filter_with_predicate(self, make_stream): + """.filter() takes a predicate on the full Observation.""" + result = make_stream(5).filter(lambda obs: obs.data > 20).fetch() + assert [o.data for o in result] == [30, 40] + + def test_filter_on_metadata(self, make_stream): + """Predicates can access ts, tags, pose — not just data.""" + result = make_stream(5).filter(lambda obs: obs.ts % 2 == 0).fetch() + assert [o.ts for o in result] == [0.0, 2.0, 4.0] + + def test_map(self, make_stream): + """.map() transforms each observation's data.""" + result = make_stream(3).map(lambda obs: obs.derive(data=obs.data * 2)).fetch() + assert [o.data for o in result] == [0, 20, 40] + + def test_map_preserves_ts(self, make_stream): + result = make_stream(3).map(lambda obs: obs.derive(data=str(obs.data))).fetch() + assert [o.ts for o in result] == [0.0, 1.0, 2.0] + assert [o.data for o in result] == ["0", "10", "20"] + + +# ═══════════════════════════════════════════════════════════════════ +# 7. Transform chaining +# ═══════════════════════════════════════════════════════════════════ + + +class TestTransformChaining: + """Transforms chain lazily — each obs flows through the full pipeline.""" + + def test_single_transform(self, make_stream): + xf = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) + result = make_stream(3).transform(xf).fetch() + assert [o.data for o in result] == [1, 11, 21] + + def test_chained_transforms(self, make_stream): + """stream.transform(A).transform(B) — B pulls from A which pulls from source.""" + add_one = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) + double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) + + result = make_stream(3).transform(add_one).transform(double).fetch() + # (0+1)*2=2, (10+1)*2=22, (20+1)*2=42 + assert [o.data for o in result] == [2, 22, 42] + + def test_transform_can_skip(self, make_stream): + """Returning None from a transformer skips that observation.""" + keep_even = FnTransformer(lambda obs: obs if obs.data % 20 == 0 else None) + result = make_stream(5).transform(keep_even).fetch() + assert [o.data for o in result] == [0, 20, 40] + + def test_transform_filter_transform(self, memory_session): + """stream.transform(A).near(pose).transform(B) — filter between transforms.""" + stream = memory_session.stream("tfft") + stream.append(1, ts=0.0, pose=(0, 0, 0)) + stream.append(2, ts=1.0, pose=(100, 100, 100)) + stream.append(3, ts=2.0, pose=(1, 0, 0)) + + add_ten = FnTransformer(lambda obs: obs.derive(data=obs.data + 10)) + double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) + + result = ( + stream.transform(add_ten) # 11, 12, 13 + .near((0, 0, 0), 5.0) # keeps pose at (0,0,0) and (1,0,0) + .transform(double) # 22, 26 + .fetch() + ) + assert [o.data for o in result] == [22, 26] + + def test_generator_function_transform(self, make_stream): + """A bare generator function works as a transform.""" + + def double_all(upstream): + for obs in upstream: + yield obs.derive(data=obs.data * 2) + + result = make_stream(3).transform(double_all).fetch() + assert [o.data for o in result] == [0, 20, 40] + + def test_generator_function_stateful(self, make_stream): + """Generator transforms can accumulate state and yield at their own pace.""" + + def running_sum(upstream): + total = 0 + for obs in upstream: + total += obs.data + yield obs.derive(data=total) + + result = make_stream(3).transform(running_sum).fetch() + # 0, 0+10=10, 10+20=30 + assert [o.data for o in result] == [0, 10, 30] + + def test_quality_window(self, memory_session): + """QualityWindow keeps the best item per time window.""" + stream = memory_session.stream("qw") + # Window 1: ts 0.0-0.9 → best quality + stream.append(0.3, ts=0.0) + stream.append(0.9, ts=0.3) # best in window + stream.append(0.1, ts=0.7) + # Window 2: ts 1.0-1.9 + stream.append(0.5, ts=1.0) + stream.append(0.8, ts=1.5) # best in window + # Window 3: ts 2.0+ (emitted at end via flush) + stream.append(0.6, ts=2.2) + + xf = QualityWindow(quality_fn=lambda v: v, window=1.0) + result = stream.transform(xf).fetch() + assert [o.data for o in result] == [0.9, 0.8, 0.6] + + def test_streaming_not_buffering(self, make_stream): + """Transforms process lazily — early limit stops pulling from source.""" + calls = [] + + class CountingXf(Transformer[int, int]): + def __call__(self, upstream): + for obs in upstream: + calls.append(obs.data) + yield obs + + result = make_stream(100).transform(CountingXf()).limit(3).fetch() + assert len(result) == 3 + # The transformer should have processed at most a few more than 3 + # (not all 100) due to lazy evaluation + assert len(calls) == 3 + + +# ═══════════════════════════════════════════════════════════════════ +# 8. Store +# ═══════════════════════════════════════════════════════════════════ + + +class TestStore: + """Store -> Stream hierarchy for named streams.""" + + def test_basic_store(self, memory_store): + images = memory_store.stream("images") + images.append("frame1", ts=0.0) + images.append("frame2", ts=1.0) + assert images.count() == 2 + + def test_same_stream_on_repeated_calls(self, memory_store): + s1 = memory_store.stream("images") + s2 = memory_store.stream("images") + assert s1 is s2 + + def test_list_streams(self, memory_store): + memory_store.stream("images") + memory_store.stream("lidar") + names = memory_store.list_streams() + assert "images" in names + assert "lidar" in names + assert len(names) == 2 + + def test_delete_stream(self, memory_store): + memory_store.stream("temp") + memory_store.delete_stream("temp") + assert "temp" not in memory_store.list_streams() + + +# ═══════════════════════════════════════════════════════════════════ +# 9. Lazy data loading +# ═══════════════════════════════════════════════════════════════════ + + +class TestLazyData: + """Observation.data supports lazy loading with cleanup.""" + + def test_eager_data(self): + """In-memory observations have data set directly — zero-cost access.""" + obs = Observation(id=0, ts=0.0, _data="hello") + assert obs.data == "hello" + + def test_lazy_loading(self): + """Data loaded on first access, loader released after.""" + load_count = 0 + + def loader(): + nonlocal load_count + load_count += 1 + return "loaded" + + obs = Observation(id=0, ts=0.0, _loader=loader) + assert load_count == 0 + assert obs.data == "loaded" + assert load_count == 1 + assert obs._loader is None # released + assert obs.data == "loaded" # cached, no second load + assert load_count == 1 + + def test_no_data_no_loader_raises(self): + obs = Observation(id=0, ts=0.0) + with pytest.raises(LookupError): + _ = obs.data + + def test_derive_preserves_metadata(self): + obs = Observation(id=42, ts=1.5, pose=(1, 2, 3), tags={"k": "v"}, _data="original") + derived = obs.derive(data="transformed") + assert derived.id == 42 + assert derived.ts == 1.5 + assert derived.pose == (1, 2, 3) + assert derived.tags == {"k": "v"} + assert derived.data == "transformed" + + +# ═══════════════════════════════════════════════════════════════════ +# 10. Live mode +# ═══════════════════════════════════════════════════════════════════ + + +class TestLiveMode: + """Live streams yield backfill then block for new observations.""" + + def test_live_sees_backfill_then_new(self, memory_session): + """Backfill first, then live appends come through.""" + stream = memory_session.stream("live") + stream.append("old", ts=0.0) + live = stream.live(buffer=Unbounded()) + + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 3: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append("new1", ts=1.0) + stream.append("new2", ts=2.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == ["old", "new1", "new2"] + + def test_live_with_filter(self, memory_session): + """Filters apply to live data — non-matching obs are dropped silently.""" + stream = memory_session.stream("live_filter") + live = stream.after(5.0).live(buffer=Unbounded()) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append(1, ts=1.0) # filtered out (ts <= 5.0) + stream.append(2, ts=6.0) # passes + stream.append(3, ts=3.0) # filtered out + stream.append(4, ts=10.0) # passes + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == [2, 4] + + def test_live_deduplicates_backfill_overlap(self, memory_session): + """Observations seen in backfill are not re-yielded from the live buffer.""" + stream = memory_session.stream("dedup") + stream.append("backfill", ts=0.0) + live = stream.live(buffer=Unbounded()) + + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append("live1", ts=1.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + assert results == ["backfill", "live1"] + + def test_live_with_keep_last_backpressure(self, memory_session): + """KeepLast drops intermediate values when consumer is slow.""" + stream = memory_session.stream("bp") + live = stream.live(buffer=KeepLast()) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if obs.data >= 90: + consumed.set() + return + time.sleep(0.1) # slow consumer + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + # Rapid producer — KeepLast will drop most of these + for i in range(100): + stream.append(i, ts=float(i)) + time.sleep(0.001) + + consumed.wait(timeout=5.0) + t.join(timeout=2.0) + # KeepLast means many values were dropped — far fewer than 100 + assert len(results) < 50 + assert results[-1] >= 90 + + def test_live_transform_receives_live_items(self, memory_session): + """Transforms downstream of .live() see both backfill and live items.""" + stream = memory_session.stream("live_xf") + stream.append(1, ts=0.0) + double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) + live = stream.live(buffer=Unbounded()).transform(double) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 3: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append(10, ts=1.0) + stream.append(100, ts=2.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + # All items went through the double transform + assert results == [2, 20, 200] + + def test_live_on_transform_raises(self, make_stream): + """Calling .live() on a transform stream raises TypeError.""" + stream = make_stream(3) + xf = FnTransformer(lambda obs: obs) + with pytest.raises(TypeError, match="Cannot call .live"): + stream.transform(xf).live() + + def test_is_live(self, memory_session): + """is_live() walks the source chain to detect live mode.""" + stream = memory_session.stream("is_live") + assert not stream.is_live() + + live = stream.live(buffer=Unbounded()) + assert live.is_live() + + xf = FnTransformer(lambda obs: obs) + transformed = live.transform(xf) + assert transformed.is_live() + + # Two levels deep + double_xf = transformed.transform(xf) + assert double_xf.is_live() + + # Non-live transform is not live + assert not stream.transform(xf).is_live() + + def test_search_on_live_transform_raises(self, memory_session): + """search() on a transform with live upstream raises immediately.""" + stream = memory_session.stream("live_search") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + import numpy as np + + from dimos.models.embedding.base import Embedding + + vec = Embedding(vector=np.array([1.0, 0.0, 0.0])) + with pytest.raises(TypeError, match="requires finite data"): + # Use list() to trigger iteration — fetch() would hit its own guard first + list(live_xf.search(vec, k=5)) + + def test_order_by_on_live_transform_raises(self, memory_session): + """order_by() on a transform with live upstream raises immediately.""" + stream = memory_session.stream("live_order") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="requires finite data"): + list(live_xf.order_by("ts", desc=True)) + + def test_fetch_on_live_without_limit_raises(self, memory_session): + """fetch() on a live stream without limit() raises TypeError.""" + stream = memory_session.stream("live_fetch") + live = stream.live(buffer=Unbounded()) + + with pytest.raises(TypeError, match="block forever"): + live.fetch() + + def test_fetch_on_live_transform_without_limit_raises(self, memory_session): + """fetch() on a live transform without limit() raises TypeError.""" + stream = memory_session.stream("live_fetch_xf") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="block forever"): + live_xf.fetch() + + def test_count_on_live_transform_raises(self, memory_session): + """count() on a live transform stream raises TypeError.""" + stream = memory_session.stream("live_count") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="block forever"): + live_xf.count() + + def test_last_on_live_transform_raises(self, memory_session): + """last() on a live transform raises TypeError (via order_by guard).""" + stream = memory_session.stream("live_last") + xf = FnTransformer(lambda obs: obs) + live_xf = stream.live(buffer=Unbounded()).transform(xf) + + with pytest.raises(TypeError, match="requires finite data"): + live_xf.last() + + def test_live_chained_transforms(self, memory_session): + """stream.live().transform(A).transform(B) — both transforms applied to live items.""" + stream = memory_session.stream("live_chain") + stream.append(1, ts=0.0) + add_one = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) + double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) + live = stream.live(buffer=Unbounded()).transform(add_one).transform(double) + + results: list[int] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 3: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append(10, ts=1.0) + stream.append(100, ts=2.0) + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + # (1+1)*2=4, (10+1)*2=22, (100+1)*2=202 + assert results == [4, 22, 202] + + def test_live_filter_before_live(self, memory_session): + """Filters applied before .live() work on both backfill and live items.""" + stream = memory_session.stream("live_pre_filter") + stream.append("a", ts=1.0) + stream.append("b", ts=10.0) + live = stream.after(5.0).live(buffer=Unbounded()) + + results: list[str] = [] + consumed = threading.Event() + + def consumer(): + for obs in live: + results.append(obs.data) + if len(results) >= 2: + consumed.set() + return + + t = threading.Thread(target=consumer) + t.start() + + time.sleep(0.05) + stream.append("c", ts=3.0) # filtered + stream.append("d", ts=20.0) # passes + + consumed.wait(timeout=2.0) + t.join(timeout=2.0) + # "a" filtered in backfill, "c" filtered in live + assert results == ["b", "d"] + # "a" filtered in backfill, "c" filtered in live + assert results == ["b", "d"] + assert results == ["b", "d"] + assert results == ["b", "d"] diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py new file mode 100644 index 0000000000..1e5dc35c2c --- /dev/null +++ b/dimos/memory2/transform.py @@ -0,0 +1,115 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +import inspect +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from dimos.memory2.utils.formatting import FilterRepr + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + from dimos.memory2.type.observation import Observation + +T = TypeVar("T") +R = TypeVar("R") + + +class Transformer(FilterRepr, ABC, Generic[T, R]): + """Transforms a stream of observations lazily via iterator -> iterator. + + Pull from upstream, yield transformed observations. Naturally supports + batching, windowing, fan-out. The generator cleans + up when upstream exhausts. + """ + + @abstractmethod + def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]: ... + + def __str__(self) -> str: + parts: list[str] = [] + for name in inspect.signature(self.__init__).parameters: # type: ignore[misc] + for attr in (name, f"_{name}"): + if hasattr(self, attr): + val = getattr(self, attr) + if callable(val): + parts.append(f"{name}={getattr(val, '__name__', '...')}") + else: + parts.append(f"{name}={val!r}") + break + return f"{self.__class__.__name__}({', '.join(parts)})" + + +class FnTransformer(Transformer[T, R]): + """Wraps a callable that receives an Observation and returns a new one (or None to skip).""" + + def __init__(self, fn: Callable[[Observation[T]], Observation[R] | None]) -> None: + self._fn = fn + + def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]: + fn = self._fn + for obs in upstream: + result = fn(obs) + if result is not None: + yield result + + +class FnIterTransformer(Transformer[T, R]): + """Wraps a bare ``Iterator → Iterator`` callable (e.g. a generator function).""" + + def __init__(self, fn: Callable[[Iterator[Observation[T]]], Iterator[Observation[R]]]) -> None: + self._fn = fn + + def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]: + return self._fn(upstream) + + +class QualityWindow(Transformer[T, T]): + """Keeps the highest-quality item per time window. + + Emits the best observation when the window advances. The last window + is emitted when the upstream iterator exhausts — no flush needed. + """ + + def __init__(self, quality_fn: Callable[[Any], float], window: float) -> None: + self._quality_fn = quality_fn + self._window = window + + def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + quality_fn = self._quality_fn + window = self._window + best: Observation[T] | None = None + best_score: float = -1.0 + window_start: float | None = None + + for obs in upstream: + if window_start is not None and (obs.ts - window_start) >= window: + if best is not None: + yield best + best = None + best_score = -1.0 + window_start = obs.ts + + score = quality_fn(obs.data) + if score > best_score: + best = obs + best_score = score + if window_start is None: + window_start = obs.ts + + if best is not None: + yield best diff --git a/dimos/memory2/type/filter.py b/dimos/memory2/type/filter.py new file mode 100644 index 0000000000..af453498fd --- /dev/null +++ b/dimos/memory2/type/filter.py @@ -0,0 +1,212 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, fields +from itertools import islice +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + from dimos.memory2.buffer import BackpressureBuffer + from dimos.memory2.type.observation import Observation + from dimos.models.embedding.base import Embedding + + +@dataclass(frozen=True) +class Filter(ABC): + """Any object with a .matches(obs) -> bool method can be a filter.""" + + @abstractmethod + def matches(self, obs: Observation[Any]) -> bool: ... + + def __str__(self) -> str: + args = ", ".join(f"{f.name}={getattr(self, f.name)!r}" for f in fields(self)) + return f"{self.__class__.__name__}({args})" + + +@dataclass(frozen=True) +class AfterFilter(Filter): + t: float + + def matches(self, obs: Observation[Any]) -> bool: + return obs.ts > self.t + + +@dataclass(frozen=True) +class BeforeFilter(Filter): + t: float + + def matches(self, obs: Observation[Any]) -> bool: + return obs.ts < self.t + + +@dataclass(frozen=True) +class TimeRangeFilter(Filter): + t1: float + t2: float + + def matches(self, obs: Observation[Any]) -> bool: + return self.t1 <= obs.ts <= self.t2 + + +@dataclass(frozen=True) +class AtFilter(Filter): + t: float + tolerance: float = 1.0 + + def matches(self, obs: Observation[Any]) -> bool: + return abs(obs.ts - self.t) <= self.tolerance + + +@dataclass(frozen=True) +class NearFilter(Filter): + pose: Any = field(hash=False) + radius: float = 0.0 + + def matches(self, obs: Observation[Any]) -> bool: + if obs.pose is None or self.pose is None: + return False + p1 = self.pose + p2 = obs.pose + # Support both raw (x,y,z) tuples and PoseStamped objects + if hasattr(p1, "position"): + p1 = p1.position + if hasattr(p2, "position"): + p2 = p2.position + x1, y1, z1 = _xyz(p1) + x2, y2, z2 = _xyz(p2) + dist_sq = (x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2 + return dist_sq <= self.radius**2 + + +def _xyz(p: Any) -> tuple[float, float, float]: + """Extract (x, y, z) from various pose representations.""" + if isinstance(p, (list, tuple)): + return (float(p[0]), float(p[1]), float(p[2]) if len(p) > 2 else 0.0) + return (float(p.x), float(p.y), float(getattr(p, "z", 0.0))) + + +@dataclass(frozen=True) +class TagsFilter(Filter): + tags: dict[str, Any] = field(default_factory=dict, hash=False) + + def matches(self, obs: Observation[Any]) -> bool: + for k, v in self.tags.items(): + if obs.tags.get(k) != v: + return False + return True + + +@dataclass(frozen=True) +class PredicateFilter(Filter): + """Wraps an arbitrary predicate function for use with .filter().""" + + fn: Callable[[Observation[Any]], bool] = field(hash=False) + + def matches(self, obs: Observation[Any]) -> bool: + return bool(self.fn(obs)) + + +@dataclass(frozen=True) +class StreamQuery: + filters: tuple[Filter, ...] = () + order_field: str | None = None + order_desc: bool = False + limit_val: int | None = None + offset_val: int | None = None + live_buffer: BackpressureBuffer[Any] | None = None + # Vector search (embedding similarity) + search_vec: Embedding | None = field(default=None, hash=False, compare=False) + search_k: int | None = None + # Full-text search (substring / FTS5) + search_text: str | None = None + + def __str__(self) -> str: + parts: list[str] = [str(f) for f in self.filters] + if self.search_text is not None: + parts.append(f"search({self.search_text!r})") + if self.search_vec is not None: + k = f", k={self.search_k}" if self.search_k is not None else "" + parts.append(f"vector_search({k.lstrip(', ')})" if k else "vector_search()") + if self.order_field: + direction = " DESC" if self.order_desc else "" + parts.append(f"order_by({self.order_field}{direction})") + if self.offset_val: + parts.append(f"offset({self.offset_val})") + if self.limit_val is not None: + parts.append(f"limit({self.limit_val})") + return " | ".join(parts) + + def apply( + self, it: Iterator[Observation[Any]], *, live: bool = False + ) -> Iterator[Observation[Any]]: + """Apply all query operations to an iterator in Python. + + Used as the fallback execution path for transform-sourced streams + and in-memory backends. Backends with native query support (SQL, + ANN indexes) should push down operations instead. + """ + # Filters + if self.filters: + it = (obs for obs in it if all(f.matches(obs) for f in self.filters)) + + # Text search — substring match + if self.search_text is not None: + needle = self.search_text.lower() + it = (obs for obs in it if needle in str(obs.data).lower()) + + # Vector search — brute-force cosine (materializes) + if self.search_vec is not None: + if live: + raise TypeError( + ".search() requires finite data — cannot rank an infinite live stream." + ) + query_emb = self.search_vec + scored = [] + for obs in it: + emb = getattr(obs, "embedding", None) + if emb is not None: + sim = float(emb @ query_emb) + scored.append(obs.derive(data=obs.data, similarity=sim)) + scored.sort(key=lambda o: getattr(o, "similarity", 0.0) or 0.0, reverse=True) + if self.search_k is not None: + scored = scored[: self.search_k] + it = iter(scored) + + # Sort (materializes) + if self.order_field: + if live: + raise TypeError( + ".order_by() requires finite data — cannot sort an infinite live stream." + ) + key = self.order_field + desc = self.order_desc + items = sorted( + list(it), + key=lambda obs: getattr(obs, key) if getattr(obs, key, None) is not None else 0, + reverse=desc, + ) + it = iter(items) + + # Offset + limit + if self.offset_val: + it = islice(it, self.offset_val, None) + if self.limit_val is not None: + it = islice(it, self.limit_val) + + return it diff --git a/dimos/memory2/type/observation.py b/dimos/memory2/type/observation.py new file mode 100644 index 0000000000..0a6dd16ea5 --- /dev/null +++ b/dimos/memory2/type/observation.py @@ -0,0 +1,112 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +import threading +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +if TYPE_CHECKING: + from collections.abc import Callable + + from dimos.models.embedding.base import Embedding + +T = TypeVar("T") + + +class _Unloaded: + """Sentinel indicating data has not been loaded yet.""" + + __slots__ = () + + def __repr__(self) -> str: + return "" + + +_UNLOADED = _Unloaded() + + +@dataclass +class Observation(Generic[T]): + """A single timestamped observation with optional spatial pose and metadata.""" + + id: int + ts: float + pose: Any | None = None + tags: dict[str, Any] = field(default_factory=dict) + _data: T | _Unloaded = field(default=_UNLOADED, repr=False) + _loader: Callable[[], T] | None = field(default=None, repr=False) + _data_lock: threading.Lock = field(default_factory=threading.Lock, repr=False) + + @property + def data(self) -> T: + val = self._data + if isinstance(val, _Unloaded): + with self._data_lock: + # Re-check after acquiring lock (double-checked locking) + val = self._data + if isinstance(val, _Unloaded): + if self._loader is None: + raise LookupError("No data and no loader set on this observation") + loaded = self._loader() + self._data = loaded + self._loader = None # release closure + return loaded + return val # type: ignore[return-value] + return val + + def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]: + """Create a new observation preserving ts/pose/tags, replacing data. + + If ``embedding`` is passed, promotes the result to + :class:`EmbeddedObservation`. + """ + if "embedding" in overrides: + return EmbeddedObservation( + id=self.id, + ts=overrides.get("ts", self.ts), + pose=overrides.get("pose", self.pose), + tags=overrides.get("tags", self.tags), + _data=data, + embedding=overrides["embedding"], + similarity=overrides.get("similarity"), + ) + return Observation( + id=self.id, + ts=overrides.get("ts", self.ts), + pose=overrides.get("pose", self.pose), + tags=overrides.get("tags", self.tags), + _data=data, + ) + + +@dataclass +class EmbeddedObservation(Observation[T]): + """Observation enriched with a vector embedding and optional similarity score.""" + + embedding: Embedding | None = None + similarity: float | None = None + + def derive(self, *, data: Any, **overrides: Any) -> EmbeddedObservation[Any]: + """Preserve embedding unless explicitly replaced.""" + return EmbeddedObservation( + id=self.id, + ts=overrides.get("ts", self.ts), + pose=overrides.get("pose", self.pose), + tags=overrides.get("tags", self.tags), + _data=data, + embedding=overrides.get("embedding", self.embedding), + similarity=overrides.get("similarity", self.similarity), + ) diff --git a/dimos/memory2/utils/formatting.py b/dimos/memory2/utils/formatting.py new file mode 100644 index 0000000000..ee13fb3f36 --- /dev/null +++ b/dimos/memory2/utils/formatting.py @@ -0,0 +1,58 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rich rendering helpers for memory types. + +All rich/ANSI logic lives here. Other modules import the mixin and +``render_text`` — nothing else needs to touch ``rich`` directly. +""" + +from __future__ import annotations + +from rich.console import Console +from rich.text import Text + +_console = Console(force_terminal=True, highlight=False) + + +def render_text(text: Text) -> str: + """Render rich Text to a terminal string with ANSI codes.""" + with _console.capture() as cap: + _console.print(text, end="", soft_wrap=True) + return cap.get() + + +def _colorize(plain: str) -> Text: + """Turn ``'name(args)'``, ``'a | b'``, or ``'a -> b'`` into rich Text with cyan names.""" + t = Text() + pipe = Text(" | ", style="dim") + arrow = Text(" -> ", style="dim") + for i, seg in enumerate(plain.split(" | ")): + if i > 0: + t.append_text(pipe) + for j, part in enumerate(seg.split(" -> ")): + if j > 0: + t.append_text(arrow) + name, _, rest = part.partition("(") + t.append(name, style="cyan") + if rest: + t.append(f"({rest}") + return t + + +class FilterRepr: + """Mixin for filters: subclass defines ``__str__``, gets colored ``__repr__`` free.""" + + def __repr__(self) -> str: + return render_text(_colorize(str(self))) diff --git a/dimos/memory2/utils/sqlite.py b/dimos/memory2/utils/sqlite.py new file mode 100644 index 0000000000..e242a6e1f5 --- /dev/null +++ b/dimos/memory2/utils/sqlite.py @@ -0,0 +1,43 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sqlite3 + +from reactivex.disposable import Disposable + + +def open_sqlite_connection(path: str) -> sqlite3.Connection: + """Open a WAL-mode SQLite connection with sqlite-vec loaded.""" + import sqlite_vec + + conn = sqlite3.connect(path, check_same_thread=False) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + return conn + + +def open_disposable_sqlite_connection( + path: str, +) -> tuple[Disposable, sqlite3.Connection]: + """Open a WAL-mode SQLite connection and return (disposable, connection). + + The disposable closes the connection when disposed. + """ + conn = open_sqlite_connection(path) + return Disposable(lambda: conn.close()), conn diff --git a/dimos/memory2/utils/validation.py b/dimos/memory2/utils/validation.py new file mode 100644 index 0000000000..636ff59327 --- /dev/null +++ b/dimos/memory2/utils/validation.py @@ -0,0 +1,25 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import re + +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def validate_identifier(name: str) -> None: + """Reject stream names that aren't safe SQL identifiers.""" + if not _IDENT_RE.match(name): + raise ValueError(f"Invalid stream name: {name!r}") diff --git a/dimos/memory2/vectorstore/base.py b/dimos/memory2/vectorstore/base.py new file mode 100644 index 0000000000..2b26520fd6 --- /dev/null +++ b/dimos/memory2/vectorstore/base.py @@ -0,0 +1,65 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any + +from dimos.core.resource import CompositeResource +from dimos.memory2.registry import qual +from dimos.protocol.service.spec import BaseConfig, Configurable + +if TYPE_CHECKING: + from dimos.models.embedding.base import Embedding + + +class VectorStoreConfig(BaseConfig): + pass + + +class VectorStore(Configurable[VectorStoreConfig], CompositeResource): + """Pluggable storage and ANN index for embedding vectors. + + Separates vector indexing from metadata so backends can swap + search strategies (brute-force, vec0, FAISS, Qdrant) independently. + + Same shape as BlobStore: ``put`` / ``search`` / ``delete``, keyed + by ``(stream, observation_id)``. Vector index creation is lazy — the + first ``put`` for a stream determines dimensionality. + """ + + default_config: type[VectorStoreConfig] = VectorStoreConfig + + def __init__(self, **kwargs: Any) -> None: + Configurable.__init__(self, **kwargs) + CompositeResource.__init__(self) + + @abstractmethod + def put(self, stream_name: str, key: int, embedding: Embedding) -> None: + """Store an embedding vector for the given stream and observation id.""" + ... + + @abstractmethod + def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: + """Return top-k (observation_id, similarity) pairs, descending.""" + ... + + @abstractmethod + def delete(self, stream_name: str, key: int) -> None: + """Remove a vector. Silent if missing.""" + ... + + def serialize(self) -> dict[str, Any]: + return {"class": qual(type(self)), "config": self.config.model_dump()} diff --git a/dimos/memory2/vectorstore/memory.py b/dimos/memory2/vectorstore/memory.py new file mode 100644 index 0000000000..a34ce29108 --- /dev/null +++ b/dimos/memory2/vectorstore/memory.py @@ -0,0 +1,61 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from dimos.memory2.vectorstore.base import VectorStore, VectorStoreConfig + +if TYPE_CHECKING: + from dimos.models.embedding.base import Embedding + + +class MemoryVectorStoreConfig(VectorStoreConfig): + pass + + +class MemoryVectorStore(VectorStore): + """In-memory brute-force vector store for testing. + + Stores embeddings in a dict keyed by ``(stream, observation_id)``. + Search computes cosine similarity against all vectors in the stream. + """ + + default_config: type[MemoryVectorStoreConfig] = MemoryVectorStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._vectors: dict[str, dict[int, Embedding]] = {} + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def put(self, stream_name: str, key: int, embedding: Embedding) -> None: + self._vectors.setdefault(stream_name, {})[key] = embedding + + def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: + vectors = self._vectors.get(stream_name, {}) + if not vectors: + return [] + scored = [(key, float(emb @ query)) for key, emb in vectors.items()] + scored.sort(key=lambda x: x[1], reverse=True) + return scored[:k] + + def delete(self, stream_name: str, key: int) -> None: + vectors = self._vectors.get(stream_name, {}) + vectors.pop(key, None) diff --git a/dimos/memory2/vectorstore/sqlite.py b/dimos/memory2/vectorstore/sqlite.py new file mode 100644 index 0000000000..fb4613825b --- /dev/null +++ b/dimos/memory2/vectorstore/sqlite.py @@ -0,0 +1,103 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import sqlite3 +from typing import TYPE_CHECKING, Any + +from pydantic import Field, model_validator + +from dimos.memory2.utils.sqlite import open_disposable_sqlite_connection +from dimos.memory2.utils.validation import validate_identifier +from dimos.memory2.vectorstore.base import VectorStore, VectorStoreConfig + +if TYPE_CHECKING: + from dimos.models.embedding.base import Embedding + + +class SqliteVectorStoreConfig(VectorStoreConfig): + conn: sqlite3.Connection | None = Field(default=None, exclude=True) + path: str | None = None + + @model_validator(mode="after") + def _conn_xor_path(self) -> SqliteVectorStoreConfig: + if self.conn is not None and self.path is not None: + raise ValueError("Specify either conn or path, not both") + if self.conn is None and self.path is None: + raise ValueError("Specify either conn or path") + return self + + +class SqliteVectorStore(VectorStore): + """Vector store backed by sqlite-vec's vec0 virtual tables. + + Creates one virtual table per stream: ``"{stream}_vec"``. + Dimensionality is determined lazily on the first ``put()``. + + Supports two construction modes: + + - ``SqliteVectorStore(conn=conn)`` — borrows an externally-managed connection. + - ``SqliteVectorStore(path="file.db")`` — opens and owns its own connection. + """ + + default_config = SqliteVectorStoreConfig + config: SqliteVectorStoreConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._conn: sqlite3.Connection = self.config.conn # type: ignore[assignment] # set in start() if None + self._path = self.config.path + self._tables: dict[str, int] = {} # stream_name -> dimensionality + + def _ensure_table(self, stream_name: str, dim: int) -> None: + if stream_name in self._tables: + return + validate_identifier(stream_name) + self._conn.execute( + f'CREATE VIRTUAL TABLE IF NOT EXISTS "{stream_name}_vec" ' + f"USING vec0(embedding float[{dim}] distance_metric=cosine)" + ) + self._tables[stream_name] = dim + + def start(self) -> None: + if self._conn is None: + assert self._path is not None + disposable, self._conn = open_disposable_sqlite_connection(self._path) + self.register_disposables(disposable) + + def put(self, stream_name: str, key: int, embedding: Embedding) -> None: + vec = embedding.to_numpy().tolist() + self._ensure_table(stream_name, len(vec)) + self._conn.execute( + f'INSERT OR REPLACE INTO "{stream_name}_vec" (rowid, embedding) VALUES (?, ?)', + (key, json.dumps(vec)), + ) + + def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: + if stream_name not in self._tables: + return [] + vec = query.to_numpy().tolist() + rows = self._conn.execute( + f'SELECT rowid, distance FROM "{stream_name}_vec" WHERE embedding MATCH ? AND k = ?', + (json.dumps(vec), k), + ).fetchall() + # vec0 cosine distance = 1 - cosine_similarity + return [(int(row[0]), max(0.0, 1.0 - row[1])) for row in rows] + + def delete(self, stream_name: str, key: int) -> None: + if stream_name not in self._tables: + return + self._conn.execute(f'DELETE FROM "{stream_name}_vec" WHERE rowid = ?', (key,)) diff --git a/dimos/models/embedding/clip.py b/dimos/models/embedding/clip.py index 6fb42b7ccf..10e44f1cc5 100644 --- a/dimos/models/embedding/clip.py +++ b/dimos/models/embedding/clip.py @@ -75,9 +75,9 @@ def embed_text(self, *texts: str) -> Embedding | list[Embedding]: Returns embeddings as torch.Tensor on device for efficient GPU comparisons. """ with torch.inference_mode(): - inputs = self._processor(text=list(texts), return_tensors="pt", padding=True).to( - self.config.device - ) + inputs = self._processor( + text=list(texts), return_tensors="pt", padding=True, truncation=True + ).to(self.config.device) text_features = self._model.get_text_features(**inputs) if self.config.normalize: diff --git a/dimos/models/vl/florence.py b/dimos/models/vl/florence.py index b68441328a..6fa7ba3d12 100644 --- a/dimos/models/vl/florence.py +++ b/dimos/models/vl/florence.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum from functools import cached_property from PIL import Image as PILImage @@ -23,6 +24,14 @@ from dimos.msgs.sensor_msgs.Image import Image +class CaptionDetail(Enum): + """Florence-2 caption detail level.""" + + BRIEF = "" + NORMAL = "" + DETAILED = "" + + class Florence2Model(HuggingFaceModel, Captioner): """Florence-2 captioning model from Microsoft. @@ -35,6 +44,7 @@ class Florence2Model(HuggingFaceModel, Captioner): def __init__( self, model_name: str = "microsoft/Florence-2-base", + detail: CaptionDetail = CaptionDetail.NORMAL, **kwargs: object, ) -> None: """Initialize Florence-2 model. @@ -43,9 +53,11 @@ def __init__( model_name: HuggingFace model name. Options: - "microsoft/Florence-2-base" (~0.2B, fastest) - "microsoft/Florence-2-large" (~0.8B, better quality) + detail: Caption detail level **kwargs: Additional config options (device, dtype, warmup, etc.) """ super().__init__(model_name=model_name, **kwargs) + self._task_prompt = detail.value @cached_property def _processor(self) -> AutoProcessor: @@ -53,27 +65,22 @@ def _processor(self) -> AutoProcessor: self.config.model_name, trust_remote_code=self.config.trust_remote_code ) - def caption(self, image: Image, detail: str = "normal") -> str: - """Generate a caption for the image. + _STRIP_PREFIXES = ("The image shows ", "The image is a ", "A ") - Args: - image: Input image to caption - detail: Level of detail for caption: - - "brief": Short, concise caption - - "normal": Standard caption (default) - - "detailed": More detailed description + @staticmethod + def _clean_caption(text: str) -> str: + for prefix in Florence2Model._STRIP_PREFIXES: + if text.startswith(prefix): + return text[len(prefix) :] + return text + + def caption(self, image: Image) -> str: + """Generate a caption for the image. Returns: Text description of the image """ - # Map detail level to Florence-2 task prompts - task_prompts = { - "brief": "", - "normal": "", - "detailed": "", - "more_detailed": "", - } - task_prompt = task_prompts.get(detail, "") + task_prompt = self._task_prompt # Convert to PIL pil_image = PILImage.fromarray(image.to_rgb().data) @@ -101,21 +108,18 @@ def caption(self, image: Image, detail: str = "normal") -> str: # Extract caption from parsed output caption: str = parsed.get(task_prompt, generated_text) - return caption.strip() + return self._clean_caption(caption.strip()) def caption_batch(self, *images: Image) -> list[str]: """Generate captions for multiple images efficiently. - Args: - images: Input images to caption - Returns: List of text descriptions """ if not images: return [] - task_prompt = "" + task_prompt = self._task_prompt # Convert all to PIL pil_images = [PILImage.fromarray(img.to_rgb().data) for img in images] @@ -136,7 +140,7 @@ def caption_batch(self, *images: Image) -> list[str]: ) # Decode all - generated_texts = self._processor.batch_decode(generated_ids, skip_special_tokens=False) + generated_texts = self._processor.batch_decode(generated_ids, skip_special_tokens=True) # Parse outputs captions = [] @@ -144,7 +148,7 @@ def caption_batch(self, *images: Image) -> list[str]: parsed = self._processor.post_process_generation( text, task=task_prompt, image_size=pil_img.size ) - captions.append(parsed.get(task_prompt, text).strip()) + captions.append(self._clean_caption(parsed.get(task_prompt, text).strip())) return captions diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 66c2876b62..8aee99435d 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -27,7 +27,6 @@ import reactivex as rx from reactivex import operators as ops import rerun as rr -from turbojpeg import TurboJPEG # type: ignore[import-untyped] from dimos.types.timestamped import Timestamped, TimestampedBufferCollection, to_human_readable from dimos.utils.reactive import quality_barrier @@ -377,15 +376,21 @@ def crop(self, x: int, y: int, width: int, height: int) -> Image: @property def sharpness(self) -> float: - """Return sharpness score.""" - gray = self.to_grayscale() - sx = cv2.Sobel(gray.data, cv2.CV_32F, 1, 0, ksize=5) - sy = cv2.Sobel(gray.data, cv2.CV_32F, 0, 1, ksize=5) - magnitude = cv2.magnitude(sx, sy) - mean_mag = float(magnitude.mean()) - if mean_mag <= 0: + """Return sharpness score. + + Downsamples to ~160px wide before computing Laplacian variance + for fast evaluation (~10-20x cheaper than full-res Sobel). + """ + gray = self.to_grayscale().data + # Downsample to ~160px wide for cheap evaluation + h, w = gray.shape[:2] + if w > 160: + scale = 160.0 / w + gray = cv2.resize(gray, (160, int(h * scale)), interpolation=cv2.INTER_AREA) + lap_var = cv2.Laplacian(gray, cv2.CV_32F).var() + if lap_var <= 0: return 0.0 - return float(np.clip((np.log10(mean_mag + 1) - 1.7) / 2.0, 0.0, 1.0)) + return float(np.clip((np.log10(lap_var + 1) - 1.0) / 3.0, 0.0, 1.0)) def save(self, filepath: str) -> bool: arr = self.to_opencv() @@ -504,6 +509,8 @@ def lcm_jpeg_encode(self, quality: int = 75, frame_id: str | None = None) -> byt Returns: LCM-encoded bytes with JPEG-compressed image data """ + from turbojpeg import TurboJPEG # type: ignore[import-untyped] + jpeg = TurboJPEG() msg = LCMImage() @@ -549,6 +556,8 @@ def lcm_jpeg_decode(cls, data: bytes, **kwargs: Any) -> Image: Returns: Image instance """ + from turbojpeg import TurboJPEG # type: ignore[import-untyped] + jpeg = TurboJPEG() msg = LCMImage.lcm_decode(data) diff --git a/dimos/perception/detection/type/detection3d/test_pointcloud.py b/dimos/perception/detection/type/detection3d/test_pointcloud.py index ad1c5cdf1b..2a6d7578e0 100644 --- a/dimos/perception/detection/type/detection3d/test_pointcloud.py +++ b/dimos/perception/detection/type/detection3d/test_pointcloud.py @@ -46,14 +46,14 @@ def test_detection3dpc(detection3dpc) -> None: assert aabb is not None, "Axis-aligned bounding box should not be None" # Verify AABB min values - assert aabb.min_bound[0] == pytest.approx(-3.575, abs=0.1) - assert aabb.min_bound[1] == pytest.approx(-0.375, abs=0.1) - assert aabb.min_bound[2] == pytest.approx(-0.075, abs=0.1) + assert aabb.min_bound[0] == pytest.approx(-3.575, abs=0.2) + assert aabb.min_bound[1] == pytest.approx(-0.375, abs=0.2) + assert aabb.min_bound[2] == pytest.approx(-0.075, abs=0.2) # Verify AABB max values - assert aabb.max_bound[0] == pytest.approx(-3.075, abs=0.1) - assert aabb.max_bound[1] == pytest.approx(-0.125, abs=0.1) - assert aabb.max_bound[2] == pytest.approx(0.475, abs=0.1) + assert aabb.max_bound[0] == pytest.approx(-3.075, abs=0.2) + assert aabb.max_bound[1] == pytest.approx(-0.125, abs=0.2) + assert aabb.max_bound[2] == pytest.approx(0.475, abs=0.2) # def test_point_cloud_properties(detection3dpc): """Test point cloud data and boundaries.""" @@ -68,13 +68,13 @@ def test_detection3dpc(detection3dpc) -> None: center = np.mean(points, axis=0) # Verify point cloud boundaries - assert min_pt[0] == pytest.approx(-3.575, abs=0.1) - assert min_pt[1] == pytest.approx(-0.375, abs=0.1) - assert min_pt[2] == pytest.approx(-0.075, abs=0.1) + assert min_pt[0] == pytest.approx(-3.575, abs=0.2) + assert min_pt[1] == pytest.approx(-0.375, abs=0.2) + assert min_pt[2] == pytest.approx(-0.075, abs=0.2) - assert max_pt[0] == pytest.approx(-3.075, abs=0.1) - assert max_pt[1] == pytest.approx(-0.125, abs=0.1) - assert max_pt[2] == pytest.approx(0.475, abs=0.1) + assert max_pt[0] == pytest.approx(-3.075, abs=0.2) + assert max_pt[1] == pytest.approx(-0.125, abs=0.2) + assert max_pt[2] == pytest.approx(0.475, abs=0.2) assert center[0] == pytest.approx(-3.326, abs=0.1) assert center[1] == pytest.approx(-0.202, abs=0.1) diff --git a/dimos/utils/docs/doclinks.py b/dimos/utils/docs/doclinks.py index 2cf5d1702f..4d2fb6dc1c 100644 --- a/dimos/utils/docs/doclinks.py +++ b/dimos/utils/docs/doclinks.py @@ -360,12 +360,13 @@ def replace_code_match(match: re.Match[str]) -> str: resolved_path = resolve_candidates(candidates, file_ref) if resolved_path is None: + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"'{file_ref}' matches multiple files: {[str(c) for c in candidates]}" + f"'{file_ref}' in {doc_rel} matches multiple files: {[str(c) for c in candidates]}" ) else: - errors.append(f"No file matching '{file_ref}' found in codebase") + errors.append(f"No file matching '{file_ref}' found in codebase (in {doc_rel})") return full_match # Determine line fragment @@ -438,12 +439,13 @@ def replace_link_match(match: re.Match[str]) -> str: if result != full_match: changes.append(f" {link_text}: .md -> {new_link}") return result + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"'{link_text}' matches multiple docs: {[str(c) for c in candidates]}" + f"'{link_text}' in {doc_rel} matches multiple docs: {[str(c) for c in candidates]}" ) else: - errors.append(f"No doc matching '{link_text}' found") + errors.append(f"No doc matching '{link_text}' found (in {doc_rel})") return full_match # Absolute path @@ -460,12 +462,13 @@ def replace_link_match(match: re.Match[str]) -> str: ) changes.append(f" {link_text}: {raw_link} -> {new_link} (fixed broken link)") return f"[{link_text}]({new_link})" + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"Broken link '{raw_link}': ambiguous, matches {[str(c) for c in candidates]}" + f"Broken link '{raw_link}' in {doc_rel}: ambiguous, matches {[str(c) for c in candidates]}" ) else: - errors.append(f"Broken link: '{raw_link}' does not exist") + errors.append(f"Broken link '{raw_link}' in {doc_rel}: does not exist") return full_match # Relative path — resolve from doc file's directory @@ -475,7 +478,8 @@ def replace_link_match(match: re.Match[str]) -> str: try: rel_to_root = resolved_abs.relative_to(root) except ValueError: - errors.append(f"Link '{raw_link}' resolves outside repo root") + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path + errors.append(f"Link '{raw_link}' in {doc_rel} resolves outside repo root") return full_match if resolved_abs.exists(): @@ -496,12 +500,13 @@ def replace_link_match(match: re.Match[str]) -> str: ) changes.append(f" {link_text}: {raw_link} -> {new_link} (found by search)") return f"[{link_text}]({new_link})" + doc_rel = doc_path.relative_to(root) if doc_path.is_relative_to(root) else doc_path if len(candidates) > 1: errors.append( - f"Broken link '{raw_link}': ambiguous, matches {[str(c) for c in candidates]}" + f"Broken link '{raw_link}' in {doc_rel}: ambiguous, matches {[str(c) for c in candidates]}" ) else: - errors.append(f"Broken link '{raw_link}': target not found") + errors.append(f"Broken link '{raw_link}' in {doc_rel}: target not found") return full_match # Split by ignore regions and only process non-ignored parts diff --git a/dimos/utils/threadpool.py b/dimos/utils/threadpool.py index a2adc90725..f2fd577d40 100644 --- a/dimos/utils/threadpool.py +++ b/dimos/utils/threadpool.py @@ -36,7 +36,7 @@ def get_max_workers() -> int: environment variable, defaulting to 4 times the CPU count. """ env_value = os.getenv("DIMOS_MAX_WORKERS", "") - return int(env_value) if env_value.strip() else multiprocessing.cpu_count() + return int(env_value) if env_value.strip() else min(8, multiprocessing.cpu_count()) # Create a ThreadPoolScheduler with a configurable number of workers. diff --git a/docker/python/Dockerfile b/docker/python/Dockerfile index 16b4db1807..d14b281603 100644 --- a/docker/python/Dockerfile +++ b/docker/python/Dockerfile @@ -31,7 +31,8 @@ RUN apt-get update && apt-get install -y \ qtbase5-dev-tools \ supervisor \ iproute2 # for LCM networking system config \ - liblcm-dev + liblcm-dev \ + libturbojpeg0-dev # Fix distutils-installed packages that block pip upgrades RUN apt-get purge -y python3-blinker python3-sympy python3-oauthlib || true diff --git a/docs/usage/transports/index.md b/docs/usage/transports/index.md index db931872bd..09ccb484ed 100644 --- a/docs/usage/transports/index.md +++ b/docs/usage/transports/index.md @@ -357,7 +357,7 @@ Received 2 messages: {'temperature': 23.0} ``` -See [`memory.py`](/dimos/protocol/pubsub/impl/memory.py) for the complete source. +See [`pubsub/impl/memory.py`](/dimos/protocol/pubsub/impl/memory.py) for the complete source. --- diff --git a/pyproject.toml b/pyproject.toml index 1757e01a8c..1fbd29f86f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,6 @@ dependencies = [ "annotation-protocol>=1.4.0", "lazy_loader", "plum-dispatch==2.5.7", - # Logging "structlog>=25.5.0,<26", "colorlog==6.9.0", @@ -86,6 +85,8 @@ dependencies = [ "toolz>=1.1.0", "protobuf>=6.33.5,<7", "psutil>=7.0.0", + "sqlite-vec>=0.1.6", + "lz4>=4.4.5", ] @@ -271,6 +272,7 @@ dev = [ "types-tensorflow>=2.18.0.20251008,<3", "types-tqdm>=4.67.0.20250809,<5", "types-psycopg2>=2.9.21.20251012", + "scipy-stubs>=1.15.0", "types-psutil>=7.2.2.20260130,<8", # Tools @@ -407,6 +409,7 @@ module = [ "rclpy.*", "sam2.*", "sensor_msgs.*", + "sqlite_vec", "std_msgs.*", "tf2_msgs.*", "torchreid", diff --git a/uv.lock b/uv.lock index e6ba8198a8..0d6a3a88ab 100644 --- a/uv.lock +++ b/uv.lock @@ -1686,6 +1686,7 @@ dependencies = [ { name = "dimos-viewer" }, { name = "lazy-loader" }, { name = "llvmlite" }, + { name = "lz4" }, { name = "numba" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -1706,6 +1707,7 @@ dependencies = [ { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scipy", version = "1.17.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "sortedcontainers" }, + { name = "sqlite-vec" }, { name = "structlog" }, { name = "terminaltexteffects" }, { name = "textual" }, @@ -1791,6 +1793,8 @@ dds = [ { name = "python-lsp-server", extra = ["all"] }, { name = "requests-mock" }, { name = "ruff" }, + { name = "scipy-stubs", version = "1.15.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy-stubs", version = "1.17.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "terminaltexteffects" }, { name = "types-colorama" }, { name = "types-defusedxml" }, @@ -1828,6 +1832,8 @@ dev = [ { name = "python-lsp-server", extra = ["all"] }, { name = "requests-mock" }, { name = "ruff" }, + { name = "scipy-stubs", version = "1.15.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy-stubs", version = "1.17.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "terminaltexteffects" }, { name = "types-colorama" }, { name = "types-defusedxml" }, @@ -2021,6 +2027,7 @@ requires-dist = [ { name = "lcm", marker = "extra == 'docker'" }, { name = "llvmlite", specifier = ">=0.42.0" }, { name = "lxml-stubs", marker = "extra == 'dev'", specifier = ">=0.5.1,<1" }, + { name = "lz4", specifier = ">=4.4.5" }, { name = "matplotlib", marker = "extra == 'manipulation'", specifier = ">=3.7.1" }, { name = "md-babel-py", marker = "extra == 'dev'", specifier = "==1.1.1" }, { name = "moondream", marker = "extra == 'perception'" }, @@ -2090,11 +2097,13 @@ requires-dist = [ { name = "scikit-learn", marker = "extra == 'misc'" }, { name = "scipy", specifier = ">=1.15.1" }, { name = "scipy", marker = "extra == 'docker'", specifier = ">=1.15.1" }, + { name = "scipy-stubs", marker = "extra == 'dev'", specifier = ">=1.15.0" }, { name = "sentence-transformers", marker = "extra == 'misc'" }, { name = "sortedcontainers", specifier = "==2.4.0" }, { name = "sortedcontainers", marker = "extra == 'docker'" }, { name = "sounddevice", marker = "extra == 'agents'" }, { name = "soundfile", marker = "extra == 'web'" }, + { name = "sqlite-vec", specifier = ">=0.1.6" }, { name = "sse-starlette", marker = "extra == 'web'", specifier = ">=2.2.1" }, { name = "structlog", specifier = ">=25.5.0,<26" }, { name = "structlog", marker = "extra == 'docker'", specifier = ">=25.5.0,<26" }, @@ -5556,6 +5565,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/ee/346fa473e666fe14c52fcdd19ec2424157290a032d4c41f98127bfb31ac7/numpy-2.3.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f16417ec91f12f814b10bafe79ef77e70113a2f5f7018640e7425ff979253425", size = 12967213, upload-time = "2025-11-16T22:52:39.38Z" }, ] +[[package]] +name = "numpy-typing-compat" +version = "20251206.2.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/83/dd90774d6685664cbe5525645a50c4e6c7454207aee552918790e879137f/numpy_typing_compat-20251206.2.3.tar.gz", hash = "sha256:18e00e0f4f2040fe98574890248848c7c6831a975562794da186cf4f3c90b935", size = 5009, upload-time = "2025-12-06T20:02:04.177Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/6f/dde8e2a79a3b6cbc31bc1037c1a1dbc07c90d52d946851bd7cba67e730a8/numpy_typing_compat-20251206.2.3-py3-none-any.whl", hash = "sha256:bfa2e4c4945413e84552cbd34a6d368c88a06a54a896e77ced760521b08f0f61", size = 6300, upload-time = "2025-12-06T20:01:56.664Z" }, +] + [[package]] name = "nvidia-cublas-cu12" version = "12.8.4.1" @@ -6219,6 +6240,60 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/ec/19c6cc6064c7fc8f0cd6d5b37c4747849e66040c6ca98f86565efc2c227c/optax-0.2.6-py3-none-any.whl", hash = "sha256:f875251a5ab20f179d4be57478354e8e21963373b10f9c3b762b94dcb8c36d91", size = 367782, upload-time = "2025-09-15T22:41:22.825Z" }, ] +[[package]] +name = "optype" +version = "0.9.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform == 'win32'", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/88/3c/9d59b0167458b839273ad0c4fc5f62f787058d8f5aed7f71294963a99471/optype-0.9.3.tar.gz", hash = "sha256:5f09d74127d316053b26971ce441a4df01f3a01943601d3712dd6f34cdfbaf48", size = 96143, upload-time = "2025-03-31T17:00:08.392Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/d8/ac50e2982bdc2d3595dc2bfe3c7e5a0574b5e407ad82d70b5f3707009671/optype-0.9.3-py3-none-any.whl", hash = "sha256:2935c033265938d66cc4198b0aca865572e635094e60e6e79522852f029d9e8d", size = 84357, upload-time = "2025-03-31T17:00:06.464Z" }, +] + +[[package]] +name = "optype" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "(python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "(python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "typing-extensions", marker = "python_full_version >= '3.11' and python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/d3/c88bb4bd90867356275ca839499313851af4b36fce6919ebc5e1de26e7ca/optype-0.16.0.tar.gz", hash = "sha256:fa682fd629ef6b70ba656ebc9fdd6614ba06ce13f52e0416dd8014c7e691a2d1", size = 53498, upload-time = "2026-02-19T23:37:09.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/a8/fe26515203cff140f1afc31236fb7f703d4bb4bd5679d28afcb3661c8d9f/optype-0.16.0-py3-none-any.whl", hash = "sha256:c28905713f55630b4bb8948f38e027ad13a541499ebcf957501f486da54b74d2", size = 65893, upload-time = "2026-02-19T23:37:08.217Z" }, +] + +[package.optional-dependencies] +numpy = [ + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "numpy-typing-compat", marker = "python_full_version >= '3.11'" }, +] + [[package]] name = "orbax-checkpoint" version = "0.11.32" @@ -8920,6 +8995,54 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/56/a5/df8f46ef7da168f1bc52cd86e09a9de5c6f19cc1da04454d51b7d4f43408/scipy-1.17.0-cp314-cp314t-win_arm64.whl", hash = "sha256:031121914e295d9791319a1875444d55079885bbae5bdc9c5e0f2ee5f09d34ff", size = 25246266, upload-time = "2026-01-10T21:30:45.923Z" }, ] +[[package]] +name = "scipy-stubs" +version = "1.15.3.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.11' and sys_platform == 'win32'", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "optype", version = "0.9.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/5f/35c43bd7d412add4adcd68475702571b2489b50c40b6564f808b2355e452/scipy_stubs-1.15.3.0.tar.gz", hash = "sha256:e8f76c9887461cf9424c1e2ad78ea5dac71dd4cbb383dc85f91adfe8f74d1e17", size = 275699, upload-time = "2025-05-08T16:58:35.139Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/42/cd8dc81f8060de1f14960885ad5b2d2651f41de8b93d09f3f919d6567a5a/scipy_stubs-1.15.3.0-py3-none-any.whl", hash = "sha256:a251254cf4fd6e7fb87c55c1feee92d32ddbc1f542ecdf6a0159cdb81c2fb62d", size = 459062, upload-time = "2025-05-08T16:58:33.356Z" }, +] + +[[package]] +name = "scipy-stubs" +version = "1.17.1.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "(python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "(python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", +] +dependencies = [ + { name = "optype", version = "0.16.0", source = { registry = "https://pypi.org/simple" }, extra = ["numpy"], marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/ad/413b0d18efca7bb48574d28e91253409d91ee6121e7937022d0d380dfc6a/scipy_stubs-1.17.1.0.tar.gz", hash = "sha256:5dc51c21765b145c2d132b96b63ff4f835dd5fb768006876d1554e7a59c61571", size = 381420, upload-time = "2026-02-23T10:33:04.742Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/ee/c6811e04ff9d5dd1d92236e8df7ebc4db6aa65c70b9938cec293348b8ec4/scipy_stubs-1.17.1.0-py3-none-any.whl", hash = "sha256:5c9c84993d36b104acb2d187b05985eb79f73491c60d83292dd738093d53d96a", size = 587059, upload-time = "2026-02-23T10:33:02.845Z" }, +] + [[package]] name = "sentence-transformers" version = "5.2.2" @@ -9114,6 +9237,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" }, ] +[[package]] +name = "sqlite-vec" +version = "0.1.6" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/ed/aabc328f29ee6814033d008ec43e44f2c595447d9cccd5f2aabe60df2933/sqlite_vec-0.1.6-py3-none-macosx_10_6_x86_64.whl", hash = "sha256:77491bcaa6d496f2acb5cc0d0ff0b8964434f141523c121e313f9a7d8088dee3", size = 164075, upload-time = "2024-11-20T16:40:29.847Z" }, + { url = "https://files.pythonhosted.org/packages/a7/57/05604e509a129b22e303758bfa062c19afb020557d5e19b008c64016704e/sqlite_vec-0.1.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fdca35f7ee3243668a055255d4dee4dea7eed5a06da8cad409f89facf4595361", size = 165242, upload-time = "2024-11-20T16:40:31.206Z" }, + { url = "https://files.pythonhosted.org/packages/f2/48/dbb2cc4e5bad88c89c7bb296e2d0a8df58aab9edc75853728c361eefc24f/sqlite_vec-0.1.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b0519d9cd96164cd2e08e8eed225197f9cd2f0be82cb04567692a0a4be02da3", size = 103704, upload-time = "2024-11-20T16:40:33.729Z" }, + { url = "https://files.pythonhosted.org/packages/80/76/97f33b1a2446f6ae55e59b33869bed4eafaf59b7f4c662c8d9491b6a714a/sqlite_vec-0.1.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux1_x86_64.whl", hash = "sha256:823b0493add80d7fe82ab0fe25df7c0703f4752941aee1c7b2b02cec9656cb24", size = 151556, upload-time = "2024-11-20T16:40:35.387Z" }, + { url = "https://files.pythonhosted.org/packages/6a/98/e8bc58b178266eae2fcf4c9c7a8303a8d41164d781b32d71097924a6bebe/sqlite_vec-0.1.6-py3-none-win_amd64.whl", hash = "sha256:c65bcfd90fa2f41f9000052bcb8bb75d38240b2dae49225389eca6c3136d3f0c", size = 281540, upload-time = "2024-11-20T16:40:37.296Z" }, +] + [[package]] name = "sse-starlette" version = "3.2.0" From e6267e11ae750af64d0ddfbe6cbc9fce658b3b2b Mon Sep 17 00:00:00 2001 From: stash Date: Sun, 15 Mar 2026 15:31:02 +0800 Subject: [PATCH 12/42] docs(readme): add Trendshift trending badge (#1563) --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 0aaa2c111b..62f9f464f0 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ ![CUDA](https://img.shields.io/badge/CUDA-supported-76B900?style=flat-square&logo=nvidia&logoColor=white) [![Docker](https://img.shields.io/badge/Docker-ready-2496ED?style=flat-square&logo=docker&logoColor=white)](https://www.docker.com/) +dimensionalOS%2Fdimos | Trendshift + [Hardware](#hardware) • From 9486a90a7e25ab15c46567a187c3068adb21b5a6 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Mon, 16 Mar 2026 12:16:12 +0200 Subject: [PATCH 13/42] fix(ci): limit tests to 60 minutes max (#1557) --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index da50491f54..14eaaabab9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,6 +21,7 @@ permissions: jobs: run-tests: + timeout-minutes: 60 runs-on: [self-hosted, Linux] container: image: ghcr.io/dimensionalos/${{ inputs.dev-image }} From 3a5f000442ab0babc92036aec68459a332376ccb Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Mon, 16 Mar 2026 14:27:17 +0200 Subject: [PATCH 14/42] fix(old-scripts): remove (#1561) --- bin/filter-errors-after-date | 77 -------------- bin/filter-errors-for-user | 63 ------------ bin/mypy-ros | 44 -------- bin/re-ignore-mypy.py | 150 ---------------------------- bin/robot-debugger | 36 ------- dimos/robot/utils/README.md | 38 ------- dimos/robot/utils/robot_debugger.py | 59 ----------- 7 files changed, 467 deletions(-) delete mode 100755 bin/filter-errors-after-date delete mode 100755 bin/filter-errors-for-user delete mode 100755 bin/mypy-ros delete mode 100755 bin/re-ignore-mypy.py delete mode 100755 bin/robot-debugger delete mode 100644 dimos/robot/utils/README.md delete mode 100644 dimos/robot/utils/robot_debugger.py diff --git a/bin/filter-errors-after-date b/bin/filter-errors-after-date deleted file mode 100755 index 03c7de0ca7..0000000000 --- a/bin/filter-errors-after-date +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python3 - -# Used to filter errors to only show lines committed on or after a specific date -# Can be chained with filter-errors-for-user - -from datetime import datetime -import re -import subprocess -import sys - -_blame = {} - - -def _is_after_date(file, line_no, cutoff_date): - if file not in _blame: - _blame[file] = _get_git_blame_dates_for_file(file) - line_date = _blame[file].get(line_no) - if not line_date: - return False - return line_date >= cutoff_date - - -def _get_git_blame_dates_for_file(file_name): - try: - result = subprocess.run( - ["git", "blame", "--date=short", file_name], - capture_output=True, - text=True, - check=True, - ) - - blame_map = {} - # Each line looks like: ^abc123 (Author Name 2024-01-01 1) code - blame_pattern = re.compile(r"^[^\(]+\([^\)]+(\d{4}-\d{2}-\d{2})") - - for i, line in enumerate(result.stdout.split("\n")): - if not line: - continue - match = blame_pattern.match(line) - if match: - date_str = match.group(1) - blame_map[str(i + 1)] = date_str - - return blame_map - except subprocess.CalledProcessError: - return {} - - -def main(): - if len(sys.argv) != 2: - print("Usage: filter-errors-after-date ", file=sys.stderr) - print(" Example: filter-errors-after-date 2025-10-04", file=sys.stderr) - sys.exit(1) - - cutoff_date = sys.argv[1] - - try: - datetime.strptime(cutoff_date, "%Y-%m-%d") - except ValueError: - print(f"Error: Invalid date format '{cutoff_date}'. Use YYYY-MM-DD", file=sys.stderr) - sys.exit(1) - - for line in sys.stdin.readlines(): - split = re.findall(r"^([^:]+):(\d+):(.*)", line) - if not split or len(split[0]) != 3: - continue - - file, line_no = split[0][:2] - if not file.startswith("dimos/"): - continue - - if _is_after_date(file, line_no, cutoff_date): - print(":".join(split[0])) - - -if __name__ == "__main__": - main() diff --git a/bin/filter-errors-for-user b/bin/filter-errors-for-user deleted file mode 100755 index 045b30b293..0000000000 --- a/bin/filter-errors-for-user +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python3 - -# Used when running `./bin/mypy-strict --for-me` - -import re -import subprocess -import sys - -_blame = {} - - -def _is_for_user(file, line_no, user_email): - if file not in _blame: - _blame[file] = _get_git_blame_for_file(file) - return _blame[file][line_no] == user_email - - -def _get_git_blame_for_file(file_name): - try: - result = subprocess.run( - ["git", "blame", "--show-email", "-e", file_name], - capture_output=True, - text=True, - check=True, - ) - - blame_map = {} - # Each line looks like: ^abc123 ( 2024-01-01 12:00:00 +0000 1) code - blame_pattern = re.compile(r"^[^\(]+\(<([^>]+)>") - - for i, line in enumerate(result.stdout.split("\n")): - if not line: - continue - match = blame_pattern.match(line) - if match: - email = match.group(1) - blame_map[str(i + 1)] = email - - return blame_map - except subprocess.CalledProcessError: - return {} - - -def main(): - if len(sys.argv) != 2: - print("Usage: filter-errors-for-user ", file=sys.stderr) - sys.exit(1) - - user_email = sys.argv[1] - - for line in sys.stdin.readlines(): - split = re.findall(r"^([^:]+):(\d+):(.*)", line) - if not split or len(split[0]) != 3: - continue - file, line_no = split[0][:2] - if not file.startswith("dimos/"): - continue - if _is_for_user(file, line_no, user_email): - print(":".join(split[0])) - - -if __name__ == "__main__": - main() diff --git a/bin/mypy-ros b/bin/mypy-ros deleted file mode 100755 index d46d6a542e..0000000000 --- a/bin/mypy-ros +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash - -set -euo pipefail - -ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" - -mypy_args=(--show-error-codes --hide-error-context --no-pretty) - -main() { - cd "$ROOT" - - if [ -z "$(docker images -q dimos-ros-dev)" ]; then - (cd docker/ros; docker build -t dimos-ros .) - docker build -t dimos-ros-python --build-arg FROM_IMAGE=dimos-ros -f docker/python/Dockerfile . - docker build -t dimos-ros-dev --build-arg FROM_IMAGE=dimos-ros-python -f docker/dev/Dockerfile . - fi - - sudo rm -fr .mypy_cache_docker - rm -fr .mypy_cache_local - - { - mypy_docker & - mypy_local & - wait - } | sort -u -} - -cleaned() { - grep ': error: ' | sort -} - -mypy_docker() { - docker run --rm -v $(pwd):/app -w /app dimos-ros-dev bash -c " - source /opt/ros/humble/setup.bash && - MYPYPATH=/opt/ros/humble/lib/python3.10/site-packages mypy ${mypy_args[*]} --cache-dir .mypy_cache_docker dimos - " | cleaned -} - -mypy_local() { - MYPYPATH=/opt/ros/jazzy/lib/python3.12/site-packages \ - mypy "${mypy_args[@]}" --cache-dir .mypy_cache_local dimos | cleaned -} - -main "$@" diff --git a/bin/re-ignore-mypy.py b/bin/re-ignore-mypy.py deleted file mode 100755 index 7d71bcd986..0000000000 --- a/bin/re-ignore-mypy.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import defaultdict -from pathlib import Path -import re -import subprocess - - -def remove_type_ignore_comments(directory: Path) -> None: - # Pattern matches "# type: ignore" with optional error codes in brackets. - # Captures any trailing comment after `type: ignore`. - type_ignore_pattern = re.compile(r"(\s*)#\s*type:\s*ignore(?:\[[^\]]*\])?(\s*#.*)?") - - for py_file in directory.rglob("*.py"): - try: - content = py_file.read_text() - except Exception: - continue - - new_lines = [] - modified = False - - for line in content.splitlines(keepends=True): - match = type_ignore_pattern.search(line) - if match: - before = line[: match.start()] - trailing_comment = match.group(2) - - if trailing_comment: - new_line = before + match.group(1) + trailing_comment.lstrip() - else: - new_line = before - - if line.endswith("\n"): - new_line = new_line.rstrip() + "\n" - else: - new_line = new_line.rstrip() - new_lines.append(new_line) - modified = True - else: - new_lines.append(line) - - if modified: - try: - py_file.write_text("".join(new_lines)) - except Exception: - pass - - -def run_mypy(root: Path) -> str: - result = subprocess.run( - [str(root / "bin" / "mypy-ros")], - capture_output=True, - text=True, - cwd=root, - ) - return result.stdout + result.stderr - - -def parse_mypy_errors(output: str) -> dict[Path, dict[int, list[str]]]: - error_pattern = re.compile(r"^(.+):(\d+): error: .+\[([^\]]+)\]\s*$") - errors: dict[Path, dict[int, list[str]]] = defaultdict(lambda: defaultdict(list)) - - for line in output.splitlines(): - match = error_pattern.match(line) - if match: - file_path = Path(match.group(1)) - line_num = int(match.group(2)) - error_code = match.group(3) - if error_code not in errors[file_path][line_num]: - errors[file_path][line_num].append(error_code) - - return errors - - -def add_type_ignore_comments(root: Path, errors: dict[Path, dict[int, list[str]]]) -> None: - comment_pattern = re.compile(r"^([^#]*?)( #.*)$") - - for file_path, line_errors in errors.items(): - full_path = root / file_path - if not full_path.exists(): - continue - - try: - content = full_path.read_text() - except Exception: - continue - - lines = content.splitlines(keepends=True) - modified = False - - for line_num, error_codes in line_errors.items(): - if line_num < 1 or line_num > len(lines): - continue - - idx = line_num - 1 - line = lines[idx] - codes_str = ", ".join(sorted(error_codes)) - ignore_comment = f" # type: ignore[{codes_str}]" - - has_newline = line.endswith("\n") - line_content = line.rstrip("\n") - - comment_match = comment_pattern.match(line_content) - if comment_match: - code_part = comment_match.group(1) - existing_comment = comment_match.group(2) - new_line = code_part + ignore_comment + existing_comment - else: - new_line = line_content + ignore_comment - - if has_newline: - new_line += "\n" - - lines[idx] = new_line - modified = True - - if modified: - try: - full_path.write_text("".join(lines)) - except Exception: - pass - - -def main() -> None: - root = Path(__file__).parent.parent - dimos_dir = root / "dimos" - - remove_type_ignore_comments(dimos_dir) - mypy_output = run_mypy(root) - errors = parse_mypy_errors(mypy_output) - add_type_ignore_comments(root, errors) - - -if __name__ == "__main__": - main() diff --git a/bin/robot-debugger b/bin/robot-debugger deleted file mode 100755 index 165a546a0c..0000000000 --- a/bin/robot-debugger +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -# Control the robot with a python shell (for debugging). -# -# You have to start the robot run file with: -# -# ROBOT_DEBUGGER=true python -# -# And now start this script -# -# $ ./bin/robot-debugger -# >>> robot.explore() -# True -# >>> - - -exec python -i <(cat < 0: - print("\nConnected.") - break - except ConnectionRefusedError: - print("Not started yet. Trying again...") - time.sleep(2) -else: - print("Failed to connect. Is it started?") - exit(1) - -robot = c.root.robot() -EOF -) diff --git a/dimos/robot/utils/README.md b/dimos/robot/utils/README.md deleted file mode 100644 index 5a84b20c4a..0000000000 --- a/dimos/robot/utils/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# Robot Utils - -## RobotDebugger - -The `RobotDebugger` provides a way to debug a running robot through the python shell. - -Requirements: - -```bash -pip install rpyc -``` - -### Usage - -1. **Add to your robot application:** - ```python - from dimos.robot.utils.robot_debugger import RobotDebugger - - # In your robot application's context manager or main loop: - with RobotDebugger(robot): - # Your robot code here - pass - - # Or better, with an exit stack. - exit_stack.enter_context(RobotDebugger(robot)) - ``` - -2. **Start your robot with debugging enabled:** - ```bash - ROBOT_DEBUGGER=true python your_robot_script.py - ``` - -3. **Open the python shell:** - ```bash - ./bin/robot-debugger - >>> robot.explore() - True - ``` diff --git a/dimos/robot/utils/robot_debugger.py b/dimos/robot/utils/robot_debugger.py deleted file mode 100644 index c7f3cd7291..0000000000 --- a/dimos/robot/utils/robot_debugger.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -from dimos.core.resource import Resource -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() - - -class RobotDebugger(Resource): - def __init__(self, robot) -> None: # type: ignore[no-untyped-def] - self._robot = robot - self._threaded_server = None - - def start(self) -> None: - if not os.getenv("ROBOT_DEBUGGER"): - return - - try: - import rpyc # type: ignore[import-not-found] - from rpyc.utils.server import ThreadedServer # type: ignore[import-not-found] - except ImportError: - return - - logger.info( - "Starting the robot debugger. You can open a python shell with `./bin/robot-debugger`" - ) - - robot = self._robot - - class RobotService(rpyc.Service): # type: ignore[misc] - def exposed_robot(self): # type: ignore[no-untyped-def] - return robot - - self._threaded_server = ThreadedServer( - RobotService, - port=18861, - protocol_config={ - "allow_all_attrs": True, - }, - ) - self._threaded_server.start() # type: ignore[attr-defined] - - def stop(self) -> None: - if self._threaded_server: - self._threaded_server.close() From ab081c1acae894a398fc7421021aa1c97955a071 Mon Sep 17 00:00:00 2001 From: stash Date: Mon, 16 Mar 2026 21:06:06 +0800 Subject: [PATCH 15/42] docs: add Spec issue template (#1574) --- .github/ISSUE_TEMPLATE/spec.yml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/spec.yml diff --git a/.github/ISSUE_TEMPLATE/spec.yml b/.github/ISSUE_TEMPLATE/spec.yml new file mode 100644 index 0000000000..016399f1dc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/spec.yml @@ -0,0 +1,28 @@ +name: Spec +description: Technical specification for a new module, feature, or system change +title: "[Spec]: " +body: + - type: textarea + id: spec + attributes: + label: Specification + description: Full technical spec in markdown + value: | + ## Summary + + + ## Motivation + + + ## Design + + ### API / Interface + + + ### Architecture + + + ### Implementation Notes + + validations: + required: true From e95e0d7be9f713a77156bd86ab8194f25d94ab65 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Tue, 17 Mar 2026 02:50:34 +0200 Subject: [PATCH 16/42] feat(patrol): add patrolling module (#1488) --- dimos/agents/agent.py | 78 +++++++- dimos/agents/mcp/mcp_client.py | 65 ++++++- dimos/agents/skills/person_follow.py | 76 ++++++-- dimos/core/global_config.py | 3 +- dimos/e2e_tests/conftest.py | 9 + dimos/e2e_tests/test_patrol_and_follow.py | 86 +++++++++ dimos/mapping/occupancy/gradient.py | 10 +- dimos/mapping/occupancy/path_map.py | 14 +- dimos/mapping/occupancy/test_path_map.py | 2 +- dimos/models/segmentation/edge_tam.py | 5 +- .../patrolling/create_patrol_router.py | 45 +++++ dimos/navigation/patrolling/module.py | 143 +++++++++++++++ .../patrolling/patrolling_module_spec.py | 27 +++ .../patrolling/routers/base_patrol_router.py | 78 ++++++++ .../routers/coverage_patrol_router.py | 164 +++++++++++++++++ .../routers/frontier_patrol_router.py | 125 +++++++++++++ .../patrolling/routers/patrol_router.py | 27 +++ .../routers/random_patrol_router.py | 68 +++++++ .../patrolling/routers/visitation_history.py | 121 ++++++++++++ .../patrolling/test_create_patrol_router.py | 101 ++++++++++ dimos/navigation/patrolling/utilities.py | 26 +++ .../replanning_a_star/controllers.py | 8 +- .../replanning_a_star/global_planner.py | 47 ++++- .../replanning_a_star/local_planner.py | 6 +- .../replanning_a_star/min_cost_astar.py | 3 +- .../replanning_a_star/min_cost_astar_cpp.cpp | 4 +- dimos/navigation/replanning_a_star/module.py | 12 ++ .../replanning_a_star/module_spec.py | 29 +++ .../replanning_a_star/navigation_map.py | 8 +- .../replanning_a_star/position_tracker.py | 4 +- .../replanning_a_star/test_min_cost_astar.py | 48 +++++ .../detection/type/detection2d/test_bbox.py | 1 + .../test_temporal_memory_module.py | 21 +-- dimos/perception/perceive_loop_skill.py | 86 ++++++++- .../perception/test_spatial_memory_module.py | 7 +- .../go2/blueprints/smart/unitree_go2.py | 2 + .../mujoco/direct_cmd_vel_explorer.py | 107 +++++++++++ .../navigation/native/assets/coverage.png | 3 + .../navigation/native/assets/frontier.png | 3 + .../navigation/native/assets/patrol_path.png | 3 + .../navigation/native/assets/random.png | 3 + docs/capabilities/navigation/native/index.md | 36 ++++ misc/optimize_patrol/optimize_candidates.py | 126 +++++++++++++ .../optimize_candidates_child.py | 173 ++++++++++++++++++ .../optimize_patrol/optimize_patrol_router.py | 117 ++++++++++++ .../optimize_patrol_router_child.py | 155 ++++++++++++++++ misc/optimize_patrol/plot_path.py | 131 +++++++++++++ pyproject.toml | 12 +- uv.lock | 4 +- 49 files changed, 2360 insertions(+), 72 deletions(-) create mode 100644 dimos/e2e_tests/test_patrol_and_follow.py create mode 100644 dimos/navigation/patrolling/create_patrol_router.py create mode 100644 dimos/navigation/patrolling/module.py create mode 100644 dimos/navigation/patrolling/patrolling_module_spec.py create mode 100644 dimos/navigation/patrolling/routers/base_patrol_router.py create mode 100644 dimos/navigation/patrolling/routers/coverage_patrol_router.py create mode 100644 dimos/navigation/patrolling/routers/frontier_patrol_router.py create mode 100644 dimos/navigation/patrolling/routers/patrol_router.py create mode 100644 dimos/navigation/patrolling/routers/random_patrol_router.py create mode 100644 dimos/navigation/patrolling/routers/visitation_history.py create mode 100644 dimos/navigation/patrolling/test_create_patrol_router.py create mode 100644 dimos/navigation/patrolling/utilities.py create mode 100644 dimos/navigation/replanning_a_star/module_spec.py create mode 100644 dimos/simulation/mujoco/direct_cmd_vel_explorer.py create mode 100644 docs/capabilities/navigation/native/assets/coverage.png create mode 100644 docs/capabilities/navigation/native/assets/frontier.png create mode 100644 docs/capabilities/navigation/native/assets/patrol_path.png create mode 100644 docs/capabilities/navigation/native/assets/random.png create mode 100644 misc/optimize_patrol/optimize_candidates.py create mode 100644 misc/optimize_patrol/optimize_candidates_child.py create mode 100644 misc/optimize_patrol/optimize_patrol_router.py create mode 100644 misc/optimize_patrol/optimize_patrol_router_child.py create mode 100644 misc/optimize_patrol/plot_path.py diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index 6e24cee870..672d30c3de 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -32,6 +32,9 @@ from dimos.core.stream import In, Out from dimos.protocol.rpc.spec import RPCSpec from dimos.spec.utils import Spec +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() if TYPE_CHECKING: from langchain_core.language_models import BaseChatModel @@ -52,6 +55,7 @@ class Agent(Module[AgentConfig]): _lock: RLock _state_graph: CompiledStateGraph[Any, Any, Any, Any] | None _message_queue: Queue[BaseMessage] + _skill_registry: dict[str, SkillInfo] _history: list[BaseMessage] _thread: Thread _stop_event: Event @@ -62,6 +66,7 @@ def __init__(self, **kwargs: Any) -> None: self._state_graph = None self._message_queue = Queue() self._history = [] + self._skill_registry = {} self._thread = Thread( target=self._thread_loop, name=f"{self.__class__.__name__}-thread", @@ -100,13 +105,16 @@ def on_system_modules(self, modules: list[RPCClient]) -> None: model = MockModel(json_path=self.config.model_fixture) + skills = [skill for module in modules for skill in (module.get_skills() or [])] + self._skill_registry = {skill.func_name: skill for skill in skills} + with self._lock: # Here to prevent unwanted imports in the file. from langchain.agents import create_agent self._state_graph = create_agent( model=model, - tools=_get_tools_from_modules(self, modules, self.rpc), + tools=[_skill_to_tool(self, skill, self.rpc) for skill in skills], system_prompt=self.config.system_prompt, ) self._thread.start() @@ -115,6 +123,64 @@ def on_system_modules(self, modules: list[RPCClient]) -> None: def add_message(self, message: BaseMessage) -> None: self._message_queue.put(message) + @rpc + def dispatch_continuation( + self, continuation: dict[str, Any], continuation_context: dict[str, Any] + ) -> None: + """Execute a tool continuation with detection data, bypassing the LLM. + + Called by trigger tools (e.g. look_out_for) to immediately invoke a + follow-up tool when a detection fires, without waiting for the LLM to + reason about the next action. + + Args: + continuation: ``{"tool": "", "args": {…}}`` — the tool to + call and its arguments. Argument values that are strings + starting with ``$`` are treated as template variables and + resolved against *continuation_context* (e.g. ``"$bbox"``). + continuation_context: runtime detection data, e.g. + ``{"bbox": [x1, y1, x2, y2], "label": "person"}``. + """ + tool_name = continuation.get("tool") + if not tool_name: + self._message_queue.put( + HumanMessage(f"Continuation failed: missing 'tool' key in {continuation}") + ) + return + + skill_info = self._skill_registry.get(tool_name) + if skill_info is None: + self._message_queue.put( + HumanMessage(f"Continuation failed: tool '{tool_name}' not found") + ) + return + + tool_args: dict[str, Any] = dict(continuation.get("args", {})) + + # Substitute $-prefixed template variables from continuation_context + for key, value in tool_args.items(): + if isinstance(value, str) and value.startswith("$"): + context_key = value[1:] + if context_key in continuation_context: + tool_args[key] = continuation_context[context_key] + + rpc_call = RpcCall(None, self.rpc, skill_info.func_name, skill_info.class_name, []) + try: + result = rpc_call(**tool_args) + except Exception as e: + self._message_queue.put( + HumanMessage(f"Continuation '{tool_name}' failed with error: {e}") + ) + return + + label = continuation_context.get("label", "unknown") + self._message_queue.put( + HumanMessage( + f"Automatically executed '{tool_name}' as a continuation of lookout " + f"detection (detected: {label}). Result: {result or 'started'}" + ) + ) + def _thread_loop(self) -> None: while not self._stop_event.is_set(): try: @@ -148,13 +214,9 @@ def _process_message( class AgentSpec(Spec, Protocol): def add_message(self, message: BaseMessage) -> None: ... - - -def _get_tools_from_modules( - agent: Agent, modules: list[RPCClient], rpc: RPCSpec -) -> list[StructuredTool]: - skills = [skill for module in modules for skill in (module.get_skills() or [])] - return [_skill_to_tool(agent, skill, rpc) for skill in skills] + def dispatch_continuation( + self, continuation: dict[str, Any], continuation_context: dict[str, Any] + ) -> None: ... def _skill_to_tool(agent: Agent, skill: SkillInfo, rpc: RPCSpec) -> StructuredTool: diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index a2ee872e16..b32d195de8 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -54,6 +54,7 @@ class McpClient(Module[McpClientConfig]): _lock: RLock _state_graph: CompiledStateGraph[Any, Any, Any, Any] | None _message_queue: Queue[BaseMessage] + _tool_registry: dict[str, dict[str, Any]] _history: list[BaseMessage] _thread: Thread _stop_event: Event @@ -65,6 +66,7 @@ def __init__(self, **kwargs: Any) -> None: self._lock = RLock() self._state_graph = None self._message_queue = Queue() + self._tool_registry = {} self._history = [] self._thread = Thread( target=self._thread_loop, @@ -104,7 +106,9 @@ def _fetch_tools(self, timeout: float = 60.0, interval: float = 1.0) -> list[Str f"Failed to fetch tools from MCP server {self.config.mcp_server_url}" ) - tools = [self._mcp_tool_to_langchain(t) for t in result.get("tools", [])] + raw_tools = result.get("tools", []) + self._tool_registry = {t["name"]: t for t in raw_tools} + tools = [self._mcp_tool_to_langchain(t) for t in raw_tools] if not tools: logger.warning("No tools found from MCP server.") @@ -196,6 +200,65 @@ def stop(self) -> None: def add_message(self, message: BaseMessage) -> None: self._message_queue.put(message) + @rpc + def dispatch_continuation( + self, continuation: dict[str, Any], continuation_context: dict[str, Any] + ) -> None: + """Execute a tool continuation with detection data, bypassing the LLM. + + Called by trigger tools (e.g. look_out_for) to immediately invoke a + follow-up tool when a detection fires, without waiting for the LLM to + reason about the next action. + + Args: + continuation: ``{"tool": "", "args": {…}}`` — the tool to + call and its arguments. Argument values that are strings + starting with ``$`` are treated as template variables and + resolved against *continuation_context* (e.g. ``"$bbox"``). + continuation_context: runtime detection data, e.g. + ``{"bbox": [x1, y1, x2, y2], "label": "person"}``. + """ + tool_name = continuation.get("tool") + if not tool_name: + self._message_queue.put( + HumanMessage(f"Continuation failed: missing 'tool' key in {continuation}") + ) + return + + if tool_name not in self._tool_registry: + self._message_queue.put( + HumanMessage(f"Continuation failed: tool '{tool_name}' not found") + ) + return + + tool_args: dict[str, Any] = dict(continuation.get("args", {})) + + # Substitute $-prefixed template variables from continuation_context + for key, value in tool_args.items(): + if isinstance(value, str) and value.startswith("$"): + context_key = value[1:] + if context_key in continuation_context: + tool_args[key] = continuation_context[context_key] + + try: + result = self._mcp_request("tools/call", {"name": tool_name, "arguments": tool_args}) + content = result.get("content", []) + parts = [c.get("text", "") for c in content if c.get("type") == "text"] + text = "\n".join(parts) + except Exception as e: + self._message_queue.put( + HumanMessage(f"Continuation '{tool_name}' failed with error: {e}") + ) + return + + label = continuation_context.get("label", "unknown") + self._message_queue.put( + HumanMessage( + f"Automatically executed '{tool_name}' as a continuation of lookout " + f"detection (detected: {label}). Result: {text or 'started'}" + ) + ) + def _thread_loop(self) -> None: while not self._stop_event.is_set(): try: diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index f1cafed6cd..563fcd4f59 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 from threading import Event, RLock, Thread import time from typing import Any @@ -19,6 +20,7 @@ from langchain_core.messages import HumanMessage import numpy as np from reactivex.disposable import Disposable +from turbojpeg import TurboJPEG from dimos.agents.agent import AgentSpec from dimos.agents.annotation import skill @@ -31,8 +33,9 @@ from dimos.models.vl.create import create from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo -from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.patrolling.patrolling_module_spec import PatrollingModuleSpec from dimos.navigation.visual.query import get_object_bbox_from_image from dimos.navigation.visual_servoing.detection_navigation import DetectionNavigation from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D @@ -65,6 +68,7 @@ class PersonFollowSkillContainer(Module[Config]): _agent_spec: AgentSpec _frequency: float = 20.0 # Hz - control loop frequency _max_lost_frames: int = 15 # number of frames to wait before declaring person lost + _patrolling_module_spec: PatrollingModuleSpec def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -106,7 +110,12 @@ def stop(self) -> None: super().stop() @skill - def follow_person(self, query: str) -> str: + def follow_person( + self, + query: str, + initial_bbox: list[float] | None = None, + initial_image: str | None = None, + ) -> str: """Follow a person matching the given description using visual servoing. The robot will continuously track and follow the person, while keeping @@ -114,6 +123,12 @@ def follow_person(self, query: str) -> str: Args: query: Description of the person to follow (e.g., "man with blue shirt") + initial_bbox: Optional pre-computed bounding box [x1, y1, x2, y2]. + If provided, skips the initial VL model detection step. This is + used by the continuation system to pass detection data directly + from look_out_for, avoiding a redundant detection. + initial_image: Optional base64-encoded JPEG of the frame on which + initial_bbox was detected. Returns: Status message indicating the result of the following action. @@ -133,16 +148,27 @@ def follow_person(self, query: str) -> str: if latest_image is None: return "No image available to detect person." - initial_bbox = get_object_bbox_from_image( - self._vl_model, - latest_image, - query, - ) - - if initial_bbox is None: - return f"Could not find '{query}' in the current view." - - return self._follow_person(query, initial_bbox) + detection_image: Image | None = None + if initial_bbox is not None: + bbox: BBox = ( + initial_bbox[0], + initial_bbox[1], + initial_bbox[2], + initial_bbox[3], + ) + if initial_image is not None: + detection_image = _decode_base64_image(initial_image) + else: + detected = get_object_bbox_from_image( + self._vl_model, + latest_image, + query, + ) + if detected is None: + return f"Could not find '{query}' in the current view." + bbox = detected + + return self._follow_person(query, bbox, detection_image) @skill def stop_following(self) -> str: @@ -169,7 +195,9 @@ def _on_pointcloud(self, pointcloud: PointCloud2) -> None: with self._lock: self._latest_pointcloud = pointcloud - def _follow_person(self, query: str, initial_bbox: BBox) -> str: + def _follow_person( + self, query: str, initial_bbox: BBox, detection_image: Image | None = None + ) -> str: x1, y1, x2, y2 = initial_bbox box = np.array([x1, y1, x2, y2], dtype=np.float32) @@ -184,8 +212,11 @@ def _follow_person(self, query: str, initial_bbox: BBox) -> str: if latest_image is None: return "No image available to start tracking." + # Use the detection frame for tracker init when available, so the bbox + # matches the image it was computed on. + init_image = detection_image if detection_image is not None else latest_image initial_detections = tracker.init_track( - image=latest_image, + image=init_image, box=box, obj_id=1, ) @@ -199,11 +230,21 @@ def _follow_person(self, query: str, initial_bbox: BBox) -> str: self._thread = Thread(target=self._follow_loop, args=(tracker, query), daemon=True) self._thread.start() - return ( + message = ( "Found the person. Starting to follow. You can stop following by calling " "the 'stop_following' tool." ) + if self._patrolling_module_spec.is_patrolling(): + message += ( + " Note: since the robot was patrolling, this has been stopped automatically " + "(the equivalent of calling the `stop_patrol` tool call) so you don't have " + "to do it. " + ) + self._patrolling_module_spec.stop_patrol() + + return message + def _follow_loop(self, tracker: "EdgeTAMProcessor", query: str) -> None: lost_count = 0 period = 1.0 / self._frequency @@ -267,6 +308,11 @@ def _send_stop_reason(self, query: str, reason: str) -> None: logger.info("Person follow stopped", query=query, reason=reason) +def _decode_base64_image(b64: str) -> Image: + bgr_array = TurboJPEG().decode(base64.b64decode(b64)) + return Image(data=bgr_array, format=ImageFormat.BGR) + + person_follow_skill = PersonFollowSkillContainer.blueprint __all__ = ["PersonFollowSkillContainer", "person_follow_skill"] diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index 60072ae7fd..0b070dabd9 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -17,7 +17,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict -from dimos.mapping.occupancy.path_map import NavigationStrategy from dimos.models.vl.create import VlModelName ViewerBackend: TypeAlias = Literal["rerun", "rerun-web", "rerun-connect", "foxglove", "none"] @@ -47,7 +46,7 @@ class GlobalConfig(BaseSettings): robot_model: str | None = None robot_width: float = 0.3 robot_rotation_diameter: float = 0.6 - planner_strategy: NavigationStrategy = "simple" + nerf_speed: float = 1.0 planner_robot_speed: float | None = None mcp_port: int = 9990 mcp_host: str = "0.0.0.0" diff --git a/dimos/e2e_tests/conftest.py b/dimos/e2e_tests/conftest.py index 12f4a674a6..4509a7e5e4 100644 --- a/dimos/e2e_tests/conftest.py +++ b/dimos/e2e_tests/conftest.py @@ -26,6 +26,7 @@ from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import make_vector3 from dimos.msgs.std_msgs.Bool import Bool +from dimos.simulation.mujoco.direct_cmd_vel_explorer import DirectCmdVelExplorer from dimos.simulation.mujoco.person_on_track import PersonTrackPublisher @@ -116,3 +117,11 @@ def run_person_track() -> None: thread.join(timeout=1.0) if publisher is not None: publisher.stop() + + +@pytest.fixture +def direct_cmd_vel_explorer() -> Generator[PersonTrackPublisher, None, None]: + explorer = DirectCmdVelExplorer() + explorer.start() + yield explorer + explorer.stop() diff --git a/dimos/e2e_tests/test_patrol_and_follow.py b/dimos/e2e_tests/test_patrol_and_follow.py new file mode 100644 index 0000000000..642f044aa3 --- /dev/null +++ b/dimos/e2e_tests/test_patrol_and_follow.py @@ -0,0 +1,86 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +import time + +import pytest + +from dimos.e2e_tests.conf_types import StartPersonTrack +from dimos.e2e_tests.dimos_cli_call import DimosCliCall +from dimos.e2e_tests.lcm_spy import LcmSpy +from dimos.simulation.mujoco.direct_cmd_vel_explorer import DirectCmdVelExplorer + +points = [ + (0, -7.07), + (-4.16, -7.07), + (-4.45, 1.10), + (-6.72, 2.87), + (-1.78, 3.01), + (-1.54, 5.74), + (3.88, 6.16), + (2.16, 9.36), + (4.70, 3.87), + (4.67, -7.15), + (4.57, -4.19), + (-0.84, -2.78), + (-4.71, 1.17), + (4.30, 0.87), +] + + +@pytest.mark.skipif_in_ci +@pytest.mark.skipif_no_openai +@pytest.mark.mujoco +def test_patrol_and_follow( + lcm_spy: LcmSpy, + start_blueprint: Callable[[str], DimosCliCall], + human_input: Callable[[str], None], + start_person_track: StartPersonTrack, + direct_cmd_vel_explorer: DirectCmdVelExplorer, +) -> None: + start_blueprint( + "--mujoco-start-pos", + "-10.75 -6.78", + "--nerf-speed", + "0.5", + "run", + "--disable", + "spatial-memory", + "unitree-go2-agentic", + ) + + lcm_spy.save_topic("/rpc/Agent/on_system_modules/res") + lcm_spy.wait_for_saved_topic("/rpc/Agent/on_system_modules/res", timeout=120.0) + + time.sleep(5) + + print("Starting discovery.") + + # Explore the entire room by driving directly via /cmd_vel. + direct_cmd_vel_explorer.follow_points(points) + + print("Ended discovery.") + + start_person_track( + [ + (-10.75, -6.78), + (0, -7.07), + ] + ) + human_input( + "patrol around until you find a man wearing beige pants and when you do, start following him" + ) + + time.sleep(120) diff --git a/dimos/mapping/occupancy/gradient.py b/dimos/mapping/occupancy/gradient.py index c9db43088e..c74f0b5b61 100644 --- a/dimos/mapping/occupancy/gradient.py +++ b/dimos/mapping/occupancy/gradient.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Literal, TypeAlias, cast + import numpy as np from scipy import ndimage # type: ignore[import-untyped] from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid +if TYPE_CHECKING: + from numpy.typing import NDArray + + +GradientStrategy: TypeAlias = Literal["gradient", "voronoi"] + def gradient( occupancy_grid: OccupancyGrid, obstacle_threshold: int = 50, max_distance: float = 2.0 @@ -50,7 +58,7 @@ def gradient( # Compute distance transform (distance to nearest obstacle in cells) # Unknown cells are treated as if they don't exist for distance calculation - distance_cells = ndimage.distance_transform_edt(1 - obstacle_map) + distance_cells = cast("NDArray[np.float64]", ndimage.distance_transform_edt(1 - obstacle_map)) # Convert to meters and clip to max distance distance_meters = np.clip(distance_cells * occupancy_grid.resolution, 0, max_distance) # type: ignore[operator] diff --git a/dimos/mapping/occupancy/path_map.py b/dimos/mapping/occupancy/path_map.py index a99a423de8..8920c6e30b 100644 --- a/dimos/mapping/occupancy/path_map.py +++ b/dimos/mapping/occupancy/path_map.py @@ -14,7 +14,7 @@ from typing import Literal, TypeAlias -from dimos.mapping.occupancy.gradient import voronoi_gradient +from dimos.mapping.occupancy.gradient import GradientStrategy, gradient, voronoi_gradient from dimos.mapping.occupancy.inflation import simple_inflate from dimos.mapping.occupancy.operations import overlay_occupied, smooth_occupied from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid @@ -23,7 +23,10 @@ def make_navigation_map( - occupancy_grid: OccupancyGrid, robot_width: float, strategy: NavigationStrategy + occupancy_grid: OccupancyGrid, + robot_width: float, + strategy: NavigationStrategy, + gradient_strategy: GradientStrategy, ) -> OccupancyGrid: half_width = robot_width / 2 gradient_distance = 1.5 @@ -37,4 +40,9 @@ def make_navigation_map( else: raise ValueError(f"Unknown strategy: {strategy}") - return voronoi_gradient(costmap, max_distance=gradient_distance) + if gradient_strategy == "gradient": + return gradient(costmap, max_distance=gradient_distance) + elif gradient_strategy == "voronoi": + return voronoi_gradient(costmap, max_distance=gradient_distance) + else: + raise ValueError(f"Unknown gradient strategy: {gradient_strategy}") diff --git a/dimos/mapping/occupancy/test_path_map.py b/dimos/mapping/occupancy/test_path_map.py index b3e250db9d..8928e1ab92 100644 --- a/dimos/mapping/occupancy/test_path_map.py +++ b/dimos/mapping/occupancy/test_path_map.py @@ -28,7 +28,7 @@ def test_make_navigation_map(occupancy, strategy) -> None: expected = cv2.imread(get_data(f"make_navigation_map_{strategy}.png"), cv2.IMREAD_COLOR) robot_width = 0.4 - og = make_navigation_map(occupancy, robot_width, strategy=strategy) + og = make_navigation_map(occupancy, robot_width, strategy=strategy, gradient_strategy="voronoi") result = visualize_occupancy_grid(og, "rainbow") np.testing.assert_array_equal(result.data, expected) diff --git a/dimos/models/segmentation/edge_tam.py b/dimos/models/segmentation/edge_tam.py index e9744f6d81..61b06d5efd 100644 --- a/dimos/models/segmentation/edge_tam.py +++ b/dimos/models/segmentation/edge_tam.py @@ -38,7 +38,6 @@ if TYPE_CHECKING: from sam2.sam2_video_predictor import SAM2VideoPredictor -os.environ['TQDM_DISABLE'] = '1' logger = setup_logger() @@ -88,6 +87,10 @@ def __init__( self._predictor = instantiate(cfg.model, _recursive_=True) + # Suppress the per-frame "propagate in video" tqdm bar from sam2 + import sam2.sam2_video_predictor as _svp + _svp.tqdm = lambda iterable, *a, **kw: iterable + ckpt_path = str(get_data("models_edgetam") / "edgetam.pt") sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] diff --git a/dimos/navigation/patrolling/create_patrol_router.py b/dimos/navigation/patrolling/create_patrol_router.py new file mode 100644 index 0000000000..b26e8b4edf --- /dev/null +++ b/dimos/navigation/patrolling/create_patrol_router.py @@ -0,0 +1,45 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + from dimos.navigation.patrolling.routers.patrol_router import PatrolRouter + +PatrolRouterName = Literal["random", "coverage", "frontier"] + + +def create_patrol_router(name: PatrolRouterName, clearance_radius_m: float) -> PatrolRouter: + match name: + case "random": + # Inline to avoid unnecessary imports. + from dimos.navigation.patrolling.routers.random_patrol_router import RandomPatrolRouter + + return RandomPatrolRouter(clearance_radius_m) + case "coverage": + # Inline to avoid unnecessary imports. + from dimos.navigation.patrolling.routers.coverage_patrol_router import ( + CoveragePatrolRouter, + ) + + return CoveragePatrolRouter(clearance_radius_m) + case "frontier": + # Inline to avoid unnecessary imports. + from dimos.navigation.patrolling.routers.frontier_patrol_router import ( + FrontierPatrolRouter, + ) + + return FrontierPatrolRouter(clearance_radius_m) diff --git a/dimos/navigation/patrolling/module.py b/dimos/navigation/patrolling/module.py new file mode 100644 index 0000000000..48ee59699b --- /dev/null +++ b/dimos/navigation/patrolling/module.py @@ -0,0 +1,143 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import threading + +from dimos_lcm.std_msgs import Bool +from reactivex.disposable import Disposable + +from dimos.agents.annotation import skill +from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.module import Module +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.navigation.patrolling.create_patrol_router import create_patrol_router +from dimos.navigation.patrolling.routers.patrol_router import PatrolRouter +from dimos.navigation.replanning_a_star.module_spec import ReplanningAStarPlannerSpec +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class PatrollingModule(Module): + odom: In[PoseStamped] + global_costmap: In[OccupancyGrid] + goal_reached: In[Bool] + goal_request: Out[PoseStamped] + + _global_config: GlobalConfig + _router: PatrolRouter + _planner_spec: ReplanningAStarPlannerSpec + + _clearance_multiplier = 0.5 + + def __init__(self, g: GlobalConfig = global_config) -> None: + super().__init__() + self._global_config = g + clearance_radius_m = self._global_config.robot_width * self._clearance_multiplier + self._router = create_patrol_router("coverage", clearance_radius_m) + + self._patrol_lock = threading.RLock() + self._patrol_thread: threading.Thread | None = None + self._stop_event = threading.Event() + self._goal_reached_event = threading.Event() + self._goal_or_stop_event = threading.Event() + self._latest_pose: PoseStamped | None = None + + @rpc + def start(self) -> None: + super().start() + + self._disposables.add(Disposable(self.odom.subscribe(self._on_odom))) + self._disposables.add( + Disposable(self.global_costmap.subscribe(self._router.handle_occupancy_grid)) + ) + self._disposables.add(Disposable(self.goal_reached.subscribe(self._on_goal_reached))) + + @rpc + def stop(self) -> None: + self._stop_patrolling() + super().stop() + + @skill + def start_patrol(self) -> str: + """Start patrolling the known area. The robot will continuously pick patrol goals from the router and navigate to them until `stop_patrol` is called.""" + self._router.reset() + + with self._patrol_lock: + if self._patrol_thread is not None and self._patrol_thread.is_alive(): + return "Patrol is already running. Use `stop_patrol` to stop." + self._planner_spec.set_replanning_enabled(False) + self._planner_spec.set_safe_goal_clearance( + self._global_config.robot_rotation_diameter / 2 + 0.2 + ) + self._stop_event.clear() + self._patrol_thread = threading.Thread( + target=self._patrol_loop, daemon=True, name=self.__class__.__name__ + ) + self._patrol_thread.start() + return "Patrol started. Use `stop_patrol` to stop." + + @rpc + def is_patrolling(self) -> bool: + with self._patrol_lock: + return self._patrol_thread is not None and self._patrol_thread.is_alive() + + @skill + def stop_patrol(self) -> str: + """Stop the ongoing patrol.""" + self._stop_patrolling() + return "Patrol stopped." + + def _on_odom(self, msg: PoseStamped) -> None: + self._latest_pose = msg + self._router.handle_odom(msg) + + def _on_goal_reached(self, _msg: Bool) -> None: + self._goal_reached_event.set() + self._goal_or_stop_event.set() + + def _patrol_loop(self) -> None: + while not self._stop_event.is_set(): + goal = self._router.next_goal() + if goal is None: + logger.info("No patrol goal available, retrying in 2s") + if self._stop_event.wait(timeout=2.0): + break + continue + + self._goal_reached_event.clear() + self.goal_request.publish(goal) + + # Wait until goal is reached or stop is requested. + self._goal_or_stop_event.wait() + self._goal_or_stop_event.clear() + + def _stop_patrolling(self) -> None: + self._stop_event.set() + self._goal_or_stop_event.set() + self._planner_spec.set_replanning_enabled(True) + self._planner_spec.reset_safe_goal_clearance() + + # Publish current position as goal to cancel in-progress navigation. + pose = self._latest_pose + if pose is not None: + self.goal_request.publish(pose) + with self._patrol_lock: + if self._patrol_thread is not None: + self._patrol_thread.join() + self._patrol_thread = None diff --git a/dimos/navigation/patrolling/patrolling_module_spec.py b/dimos/navigation/patrolling/patrolling_module_spec.py new file mode 100644 index 0000000000..23dffeec16 --- /dev/null +++ b/dimos/navigation/patrolling/patrolling_module_spec.py @@ -0,0 +1,27 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Protocol + +from dimos.spec.utils import Spec +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class PatrollingModuleSpec(Spec, Protocol): + def start_patrol(self) -> str: ... + def is_patrolling(self) -> bool: ... + def stop_patrol(self) -> str: ... diff --git a/dimos/navigation/patrolling/routers/base_patrol_router.py b/dimos/navigation/patrolling/routers/base_patrol_router.py new file mode 100644 index 0000000000..eef90d81ed --- /dev/null +++ b/dimos/navigation/patrolling/routers/base_patrol_router.py @@ -0,0 +1,78 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from threading import RLock +import time + +import numpy as np +from numpy.typing import NDArray + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.navigation.patrolling.routers.visitation_history import VisitationHistory + + +class BasePatrolRouter(ABC): + _occupancy_grid_min_update_interval_s = 60.0 + _occupancy_grid: OccupancyGrid | None + _occupancy_grid_updated_at: float + _pose: PoseStamped | None + _lock: RLock + _clearance_radius_m: float + + def __init__(self, clearance_radius_m: float) -> None: + self._occupancy_grid = None + self._occupancy_grid_updated_at = 0.0 + self._visitation = VisitationHistory(clearance_radius_m) + self._pose = None + self._lock = RLock() + self._clearance_radius_m = clearance_radius_m + + @property + def _visited(self) -> NDArray[np.bool_] | None: + return self._visitation.visited + + def handle_occupancy_grid(self, msg: OccupancyGrid) -> None: + with self._lock: + now = time.monotonic() + if ( + self._occupancy_grid is not None + and now - self._occupancy_grid_updated_at + < self._occupancy_grid_min_update_interval_s + ): + return + self._occupancy_grid = msg + self._occupancy_grid_updated_at = now + self._visitation.update_grid(msg) + + def handle_odom(self, msg: PoseStamped) -> None: + with self._lock: + self._pose = msg + if self._occupancy_grid is None: + return + self._visitation.handle_odom(msg.position.x, msg.position.y) + + def get_saturation(self) -> float: + with self._lock: + return self._visitation.get_saturation() + + def reset(self) -> None: + with self._lock: + self._occupancy_grid = None + self._occupancy_grid_updated_at = 0.0 + self._visitation.reset() + + @abstractmethod + def next_goal(self) -> PoseStamped | None: ... diff --git a/dimos/navigation/patrolling/routers/coverage_patrol_router.py b/dimos/navigation/patrolling/routers/coverage_patrol_router.py new file mode 100644 index 0000000000..a060868d9d --- /dev/null +++ b/dimos/navigation/patrolling/routers/coverage_patrol_router.py @@ -0,0 +1,164 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from numpy.typing import NDArray +from scipy.ndimage import binary_erosion + +from dimos.mapping.occupancy.gradient import gradient, voronoi_gradient +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path +from dimos.navigation.patrolling.routers.base_patrol_router import BasePatrolRouter +from dimos.navigation.patrolling.utilities import point_to_pose_stamped +from dimos.navigation.replanning_a_star.min_cost_astar import min_cost_astar + + +class CoveragePatrolRouter(BasePatrolRouter): + _costmap: OccupancyGrid | None + _safe_mask: NDArray[np.bool_] | None + _sampling_weights: NDArray[np.float64] | None + _candidates_to_consider: int = 7 + + def __init__(self, clearance_radius_m: float) -> None: + super().__init__(clearance_radius_m) + self._costmap = None + self._safe_mask = None + self._sampling_weights = None + + def handle_occupancy_grid(self, msg: OccupancyGrid) -> None: + with self._lock: + prev = self._occupancy_grid + super().handle_occupancy_grid(msg) + if self._occupancy_grid is prev: + # Throttled — no update happened. + return + self._costmap = gradient(msg, max_distance=1.5) + + # Precompute the safe mask (cells with enough clearance from obstacles). + clearance_cells = self._visitation.clearance_radius_cells + free_mask = msg.grid == 0 + structure = np.ones((2 * clearance_cells + 1, 2 * clearance_cells + 1), dtype=bool) + self._safe_mask = binary_erosion(free_mask, structure=structure).astype(bool) + + # Precompute voronoi-based sampling weights so candidates are spread + # across different corridors/regions rather than clustering in large + # open areas. Low voronoi cost = on the skeleton (equidistant from + # walls) = high sampling weight. + voronoi = voronoi_gradient(msg, max_distance=1.5) + voronoi_cost = voronoi.grid.astype(np.float64) + # Invert: skeleton cells (cost 0) become weight 100, walls (100) become 0. + # Clamp negatives (unknown = -1) to 0. + weights = np.clip(100.0 - voronoi_cost, 0.0, 100.0) + self._sampling_weights = weights + + def next_goal(self) -> PoseStamped | None: + with self._lock: + if ( + self._occupancy_grid is None + or self._visited is None + or self._safe_mask is None + or self._costmap is None + or self._sampling_weights is None + ): + return None + occupancy_grid = self._occupancy_grid + costmap = self._costmap + safe_mask = self._safe_mask + sampling_weights = self._sampling_weights + visited = self._visited.copy() + pose = self._pose + + if pose is None: + return None + + start = (pose.position.x, pose.position.y) + + # Get candidate points from unvisited safe cells. + unvisited_safe = safe_mask & ~visited + if not np.any(unvisited_safe): + # Fall back to all safe cells if everything visited. + unvisited_safe = safe_mask + if not np.any(unvisited_safe): + return None + + safe_indices = np.argwhere(unvisited_safe) + n_candidates = min(self._candidates_to_consider, len(safe_indices)) + + # Weight candidates by voronoi score so they spread across corridors + # rather than clustering in large open areas. + weights = sampling_weights[safe_indices[:, 0], safe_indices[:, 1]] + weight_sum = weights.sum() + if weight_sum > 0: + probs = weights / weight_sum + else: + probs = None + chosen = safe_indices[ + np.random.choice(len(safe_indices), size=n_candidates, replace=False, p=probs) + ] + + best_score = -1 + best_point = None + + for row, col in chosen: + world = occupancy_grid.grid_to_world((col, row, 0)) + candidate = (world.x, world.y) + + path = min_cost_astar(costmap, candidate, start, unknown_penalty=1.0, use_cpp=True) + if path is None: + continue + + # Count how many new (unvisited) cells would be covered along this path. + new_cells = self._count_new_coverage(path, visited, occupancy_grid, safe_mask) + if new_cells > best_score: + best_score = new_cells + best_point = candidate + + if best_point is None: + return None + return point_to_pose_stamped(best_point) + + def _count_new_coverage( + self, + path: Path, + visited: NDArray[np.bool_], + occupancy_grid: OccupancyGrid, + safe_mask: NDArray[np.bool_], + ) -> int: + r = self._visitation.clearance_radius_cells + h, w = visited.shape + covered = np.zeros_like(visited) + + # Sample every few poses to avoid redundant work on dense paths. + step = max(1, r) + poses = path.poses[::step] + + for pose in poses: + grid = occupancy_grid.world_to_grid((pose.position.x, pose.position.y)) + col, row = int(grid.x), int(grid.y) + r_min = max(0, row - r) + r_max = min(h, row + r + 1) + c_min = max(0, col - r) + c_max = min(w, col + r + 1) + d_r_min = r_min - (row - r) + d_r_max = d_r_min + (r_max - r_min) + d_c_min = c_min - (col - r) + d_c_max = d_c_min + (c_max - c_min) + covered[r_min:r_max, c_min:c_max] |= self._visitation.clearance_disk[ + d_r_min:d_r_max, d_c_min:d_c_max + ] + + # New coverage = cells in covered that are not yet visited and are free space. + new = covered & ~visited & safe_mask + return int(np.count_nonzero(new)) diff --git a/dimos/navigation/patrolling/routers/frontier_patrol_router.py b/dimos/navigation/patrolling/routers/frontier_patrol_router.py new file mode 100644 index 0000000000..ed1ec18dca --- /dev/null +++ b/dimos/navigation/patrolling/routers/frontier_patrol_router.py @@ -0,0 +1,125 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np +from numpy.typing import NDArray +from scipy.ndimage import binary_erosion, label + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.navigation.patrolling.routers.base_patrol_router import BasePatrolRouter +from dimos.navigation.patrolling.utilities import point_to_pose_stamped + + +class FrontierPatrolRouter(BasePatrolRouter): + """Patrol router that picks goals based on unvisited frontier clusters. + + This router: + 1. Finds connected components of unvisited safe cells. + 2. Scores each component by size / euclidean_distance from the robot. + 3. Within the best component, picks the point farthest from the robot + to create long sweeping paths through unvisited territory. + """ + + _safe_mask: NDArray[np.bool_] | None + _min_cluster_cells: int = 20 + + def __init__(self, clearance_radius_m: float) -> None: + super().__init__(clearance_radius_m) + self._safe_mask = None + + def handle_occupancy_grid(self, msg: OccupancyGrid) -> None: + with self._lock: + prev = self._occupancy_grid + super().handle_occupancy_grid(msg) + if self._occupancy_grid is prev: + return + + clearance_cells = self._visitation.clearance_radius_cells + free_mask = msg.grid == 0 + structure = np.ones((2 * clearance_cells + 1, 2 * clearance_cells + 1), dtype=bool) + self._safe_mask = binary_erosion(free_mask, structure=structure).astype(bool) + + def next_goal(self) -> PoseStamped | None: + with self._lock: + if self._occupancy_grid is None or self._visited is None or self._safe_mask is None: + return None + occupancy_grid = self._occupancy_grid + safe_mask = self._safe_mask + visited = self._visited.copy() + pose = self._pose + + if pose is None: + return None + + # Robot position in grid coordinates. + grid_pos = occupancy_grid.world_to_grid((pose.position.x, pose.position.y)) + robot_col, robot_row = grid_pos.x, grid_pos.y + + # Unvisited safe cells. + unvisited_safe = safe_mask & ~visited + if not np.any(unvisited_safe): + unvisited_safe = safe_mask + if not np.any(unvisited_safe): + return None + + # Find connected components of unvisited safe space. + labeled, n_components = label(unvisited_safe) + if n_components == 0: + return None + + # Compute size and centroid of each component using vectorized ops. + component_ids = np.arange(1, n_components + 1) + rows, cols = np.where(labeled > 0) + labels_flat = labeled[rows, cols] + + # Size and centroid of each component. + sizes = np.bincount(labels_flat, minlength=n_components + 1)[1:] + sum_rows = np.bincount(labels_flat, weights=rows, minlength=n_components + 1)[1:] + sum_cols = np.bincount(labels_flat, weights=cols, minlength=n_components + 1)[1:] + + # Filter out tiny clusters. + valid = sizes >= self._min_cluster_cells + if not np.any(valid): + valid = sizes > 0 + + valid_ids = component_ids[valid] + valid_sizes = sizes[valid].astype(np.float64) + + # Euclidean distance from robot to each cluster centroid. + centroid_rows = sum_rows[valid] / valid_sizes + centroid_cols = sum_cols[valid] / valid_sizes + dr = centroid_rows - robot_row + dc = centroid_cols - robot_col + distances = np.maximum(np.sqrt(dr * dr + dc * dc), 1.0) + + # Score: prefer large, nearby clusters. + scores = valid_sizes / distances + + best_idx = int(np.argmax(scores)) + + # Within the best cluster, pick the point farthest from the robot. + # This creates long sweeping paths through unvisited territory instead + # of tiny movements toward a barely-shifting centroid. + cluster_mask = labeled == valid_ids[best_idx] + cluster_indices = np.argwhere(cluster_mask) + cluster_dr: NDArray[np.floating[Any]] = cluster_indices[:, 0] - robot_row + cluster_dc: NDArray[np.floating[Any]] = cluster_indices[:, 1] - robot_col + dists_sq = cluster_dr * cluster_dr + cluster_dc * cluster_dc + goal_row, goal_col = cluster_indices[np.argmax(dists_sq)] + + world = occupancy_grid.grid_to_world((int(goal_col), int(goal_row), 0)) + return point_to_pose_stamped((world.x, world.y)) diff --git a/dimos/navigation/patrolling/routers/patrol_router.py b/dimos/navigation/patrolling/routers/patrol_router.py new file mode 100644 index 0000000000..19aff4ca34 --- /dev/null +++ b/dimos/navigation/patrolling/routers/patrol_router.py @@ -0,0 +1,27 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid + + +class PatrolRouter(Protocol): + def __init__(self, clearance_radius_m: float) -> None: ... + def handle_occupancy_grid(self, msg: OccupancyGrid) -> None: ... + def handle_odom(self, msg: PoseStamped) -> None: ... + def next_goal(self) -> PoseStamped | None: ... + def get_saturation(self) -> float: ... + def reset(self) -> None: ... diff --git a/dimos/navigation/patrolling/routers/random_patrol_router.py b/dimos/navigation/patrolling/routers/random_patrol_router.py new file mode 100644 index 0000000000..67e6f8e25c --- /dev/null +++ b/dimos/navigation/patrolling/routers/random_patrol_router.py @@ -0,0 +1,68 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from numpy.typing import NDArray +from scipy.ndimage import binary_erosion + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.navigation.patrolling.routers.base_patrol_router import BasePatrolRouter +from dimos.navigation.patrolling.utilities import point_to_pose_stamped + + +class RandomPatrolRouter(BasePatrolRouter): + def next_goal(self) -> PoseStamped | None: + with self._lock: + if self._occupancy_grid is None or self._visited is None: + return None + occupancy_grid = self._occupancy_grid + visited = self._visited.copy() + point = _random_empty_spot( + occupancy_grid, clearance_m=self._clearance_radius_m, visited=visited + ) + if point is None: + return None + return point_to_pose_stamped(point) + + +def _random_empty_spot( + occupancy_grid: OccupancyGrid, + clearance_m: float, + visited: NDArray[np.bool_] | None = None, +) -> tuple[float, float] | None: + clearance_cells = int(np.ceil(clearance_m / occupancy_grid.resolution)) + + free_mask = occupancy_grid.grid == 0 + if not np.any(free_mask): + return None + + # Erode the free mask by the clearance radius so only cells with full clearance remain. + structure = np.ones((2 * clearance_cells + 1, 2 * clearance_cells + 1), dtype=bool) + safe_mask = binary_erosion(free_mask, structure=structure) + + # Prefer unvisited cells; fall back to all safe cells if everything is visited. + if visited is not None: + unvisited_safe = safe_mask & ~visited + if np.any(unvisited_safe): + safe_mask = unvisited_safe + + safe_indices = np.argwhere(safe_mask) + if len(safe_indices) == 0: + return None + + idx = safe_indices[np.random.randint(len(safe_indices))] + row, col = idx + world = occupancy_grid.grid_to_world((col, row, 0)) + return (world.x, world.y) diff --git a/dimos/navigation/patrolling/routers/visitation_history.py b/dimos/navigation/patrolling/routers/visitation_history.py new file mode 100644 index 0000000000..939da19ddb --- /dev/null +++ b/dimos/navigation/patrolling/routers/visitation_history.py @@ -0,0 +1,121 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from numpy.typing import NDArray + +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid + + +def _circular_disk(radius_cells: int) -> NDArray[np.bool_]: + y, x = np.ogrid[-radius_cells : radius_cells + 1, -radius_cells : radius_cells + 1] + return np.asarray((x * x + y * y) <= radius_cells * radius_cells) + + +class VisitationHistory: + """Tracks visited locations in world coordinates, independent of occupancy grid changes. + + When a new occupancy grid arrives, the visited mask is rebuilt from stored + world-coordinate points. To avoid unbounded growth, when the visited + saturation reaches ``_saturation_threshold`` the oldest half of the stored + points are discarded and the mask is rebuilt. + """ + + _saturation_threshold = 0.50 + _min_distance_m = 0.05 + + def __init__(self, clearance_radius_m: float) -> None: + self._points: list[tuple[float, float]] = [] + self._visited: NDArray[np.bool_] | None = None + self._grid: OccupancyGrid | None = None + self._clearance_radius_m = clearance_radius_m + self._clearance_radius_cells: int = 0 + self._clearance_disk: NDArray[np.bool_] = np.ones((1, 1), dtype=bool) + + @property + def visited(self) -> NDArray[np.bool_] | None: + return self._visited + + @property + def clearance_radius_cells(self) -> int: + return self._clearance_radius_cells + + @property + def clearance_disk(self) -> NDArray[np.bool_]: + return self._clearance_disk + + def update_grid(self, grid: OccupancyGrid) -> None: + self._grid = grid + self._clearance_radius_cells = int(np.ceil(self._clearance_radius_m / grid.resolution)) + self._clearance_disk = _circular_disk(self._clearance_radius_cells) + self._rebuild() + + def handle_odom(self, x: float, y: float) -> None: + if self._points: + lx, ly = self._points[-1] + if (x - lx) ** 2 + (y - ly) ** 2 < self._min_distance_m**2: + return + self._points.append((x, y)) + if self._visited is None or self._grid is None: + return + self._stamp(x, y) + if self.get_saturation() >= self._saturation_threshold: + n = len(self._points) + self._points = self._points[n // 2 :] + self._rebuild() + + def get_saturation(self) -> float: + grid = self._grid + visited = self._visited + if grid is None or visited is None: + return 0.0 + free_mask = grid.grid == 0 + total = int(np.count_nonzero(free_mask)) + if total == 0: + return 0.0 + visited_free = int(np.count_nonzero(visited & free_mask)) + return visited_free / total + + def reset(self) -> None: + self._points.clear() + self._visited = None + self._grid = None + + def _rebuild(self) -> None: + grid = self._grid + if grid is None: + return + self._visited = np.zeros((grid.height, grid.width), dtype=bool) + for x, y in self._points: + self._stamp(x, y) + + def _stamp(self, x: float, y: float) -> None: + grid = self._grid + visited = self._visited + if grid is None or visited is None: + return + r = self._clearance_radius_cells + grid_pos = grid.world_to_grid((x, y)) + col, row = int(grid_pos.x), int(grid_pos.y) + if row + r < 0 or row - r >= grid.height or col + r < 0 or col - r >= grid.width: + return + r_min = max(0, row - r) + r_max = min(grid.height, row + r + 1) + c_min = max(0, col - r) + c_max = min(grid.width, col + r + 1) + d_r_min = r_min - (row - r) + d_r_max = d_r_min + (r_max - r_min) + d_c_min = c_min - (col - r) + d_c_max = d_c_min + (c_max - c_min) + visited[r_min:r_max, c_min:c_max] |= self._clearance_disk[d_r_min:d_r_max, d_c_min:d_c_max] diff --git a/dimos/navigation/patrolling/test_create_patrol_router.py b/dimos/navigation/patrolling/test_create_patrol_router.py new file mode 100644 index 0000000000..abc59db94e --- /dev/null +++ b/dimos/navigation/patrolling/test_create_patrol_router.py @@ -0,0 +1,101 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +import cv2 +import numpy as np +import pytest + +from dimos.mapping.occupancy.gradient import gradient +from dimos.mapping.occupancy.path_resampling import smooth_resample_path +from dimos.mapping.occupancy.visualizations import visualize_occupancy_grid +from dimos.mapping.pointclouds.occupancy import height_cost_occupancy +from dimos.mapping.pointclouds.util import read_pointcloud +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.patrolling.create_patrol_router import create_patrol_router +from dimos.navigation.patrolling.utilities import point_to_pose_stamped +from dimos.navigation.replanning_a_star.min_cost_astar import min_cost_astar +from dimos.utils.data import get_data + + +@pytest.fixture +def big_office() -> OccupancyGrid: + data = read_pointcloud(get_data("big_office.ply")) + cloud = PointCloud2.from_numpy(np.asarray(data.points), frame_id="") + return height_cost_occupancy(cloud) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "router_name, saturation", [("random", 0.20), ("coverage", 0.30), ("frontier", 0.20)] +) +def test_patrolling_coverage(router_name, saturation, big_office) -> None: + start = (-1.03, -13.48) + robot_width = 0.4 + multiplier = 1.5 + big_office_gradient = gradient(big_office, max_distance=1.5) + router = create_patrol_router(router_name, robot_width * multiplier) + router.handle_occupancy_grid(big_office) + router.handle_odom(point_to_pose_stamped(start)) + + all_poses: list = [] + for _ in range(15): + goal = router.next_goal() + if goal is None: + continue + path = min_cost_astar( + big_office_gradient, goal.position, start, unknown_penalty=1.0, use_cpp=True + ) + if path is None: + continue + path = smooth_resample_path(path, goal, 0.1) + for pose in path.poses: + router.handle_odom(pose) + all_poses.append(pose) + start = (path.poses[-1].position.x, path.poses[-1].position.y) + + assert router.get_saturation() > saturation + + if os.environ.get("DEBUG"): + _save_coverage_image(router_name, router, all_poses, big_office, big_office_gradient) + + +def _save_coverage_image(router_name, router, all_poses, big_office, big_office_gradient) -> None: + image = visualize_occupancy_grid(big_office_gradient, "rainbow") + h, w = image.data.shape[:2] + visit_counts = np.zeros((h, w), dtype=np.float32) + radius = int(np.ceil(router._clearance_radius_m / big_office.resolution)) + stamp = np.zeros((h, w), dtype=np.uint8) + + for pose in all_poses: + grid = big_office.world_to_grid((pose.position.x, pose.position.y)) + gx, gy = int(grid.x), int(grid.y) + if 0 <= gy < h and 0 <= gx < w: + stamp[:] = 0 + cv2.circle(stamp, (gx, gy), radius, 1, -1) + visit_counts += stamp + + alpha = 0.05 + mask = visit_counts > 0 + blend = 1.0 - (1.0 - alpha) ** visit_counts + + overlay = image.data.astype(np.float32) * 0.24 + for c in range(3): + overlay[:, :, c][mask] = overlay[:, :, c][mask] * (1.0 - blend[mask]) + 255.0 * blend[mask] + + image.data = overlay.astype(np.uint8) + image.save(f"patrolling_coverage_{router_name}.png") diff --git a/dimos/navigation/patrolling/utilities.py b/dimos/navigation/patrolling/utilities.py new file mode 100644 index 0000000000..d7caffaa9c --- /dev/null +++ b/dimos/navigation/patrolling/utilities.py @@ -0,0 +1,26 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + +def point_to_pose_stamped(point: tuple[float, float]) -> PoseStamped: + pose = PoseStamped() + pose.position.x = point[0] + pose.position.y = point[1] + return pose + + +def pose_stamped_to_point(pose: PoseStamped) -> tuple[float, float]: + return (pose.position.x, pose.position.y) diff --git a/dimos/navigation/replanning_a_star/controllers.py b/dimos/navigation/replanning_a_star/controllers.py index 07ba8c7119..57b2031cf4 100644 --- a/dimos/navigation/replanning_a_star/controllers.py +++ b/dimos/navigation/replanning_a_star/controllers.py @@ -104,12 +104,12 @@ def _apply_min_velocity(self, velocity: float, min_velocity: float) -> float: return velocity def _angular_twist(self, angular_velocity: float) -> Twist: - # In simulation, add a small forward velocity to help the locomotion - # policy execute rotation (some policies don't handle pure in-place rotation). - linear_x = 0.18 if self._global_config.simulation else 0.0 + # In simulation, we need stroger values + if self._global_config.simulation and abs(angular_velocity) < 0.8: + angular_velocity = 0.8 * np.sign(angular_velocity) return Twist( - linear=Vector3(linear_x, 0.0, 0.0), + linear=Vector3(0.0, 0.0, 0.0), angular=Vector3(0.0, 0.0, angular_velocity), ) diff --git a/dimos/navigation/replanning_a_star/global_planner.py b/dimos/navigation/replanning_a_star/global_planner.py index 4c4e79cb7b..50fe0aa1f1 100644 --- a/dimos/navigation/replanning_a_star/global_planner.py +++ b/dimos/navigation/replanning_a_star/global_planner.py @@ -52,6 +52,7 @@ class GlobalPlanner(Resource): _global_config: GlobalConfig _navigation_map: NavigationMap + _navigation_map_near: NavigationMap _local_planner: LocalPlanner _position_tracker: PositionTracker _replan_limiter: ReplanLimiter @@ -60,31 +61,40 @@ class GlobalPlanner(Resource): _replan_event: Event _replan_reason: StopMessage | None _lock: RLock + _safe_goal_clearance: float _safe_goal_tolerance: float = 4.0 _goal_tolerance: float = 0.2 _rotation_tolerance: float = math.radians(15) _replan_goal_tolerance: float = 0.5 - _max_replan_attempts: int = 10 _stuck_time_window: float = 8.0 + _stuck_threshold: float = 0.4 _max_path_deviation: float = 0.9 + _replanning_enabled: bool = True def __init__(self, global_config: GlobalConfig) -> None: self.path = Subject() self.goal_reached = Subject() self._global_config = global_config - self._navigation_map = NavigationMap(self._global_config) + self._navigation_map = NavigationMap(self._global_config, "voronoi") + self._navigation_map_near = NavigationMap(self._global_config, "gradient") self._local_planner = LocalPlanner( self._global_config, self._navigation_map, self._goal_tolerance ) - self._position_tracker = PositionTracker(self._stuck_time_window) + + stuck_threshold = self._stuck_threshold + if global_config.simulation: + stuck_threshold = 1.0 + + self._position_tracker = PositionTracker(self._stuck_time_window, stuck_threshold) self._replan_limiter = ReplanLimiter() self._disposables = CompositeDisposable() self._stop_planner = Event() self._replan_event = Event() self._replan_reason = None self._lock = RLock() + self._reset_safe_goal_clearance() def start(self) -> None: self._local_planner.start() @@ -117,6 +127,7 @@ def handle_odom(self, msg: PoseStamped) -> None: def handle_global_costmap(self, msg: OccupancyGrid) -> None: self._navigation_map.update(msg) + self._navigation_map_near.update(msg) def handle_goal_request(self, goal: PoseStamped) -> None: logger.info("Got new goal", goal=str(goal)) @@ -126,6 +137,13 @@ def handle_goal_request(self, goal: PoseStamped) -> None: self._replan_limiter.reset() self._plan_path() + def set_safe_goal_clearance(self, clearance: float) -> None: + with self._lock: + self._safe_goal_clearance = clearance + + def reset_safe_goal_clearance(self) -> None: + self._reset_safe_goal_clearance() + def cancel_goal(self, *, but_will_try_again: bool = False, arrived: bool = False) -> None: logger.info("Cancelling goal.", but_will_try_again=but_will_try_again, arrived=arrived) @@ -143,6 +161,10 @@ def cancel_goal(self, *, but_will_try_again: bool = False, arrived: bool = False if not but_will_try_again: self.goal_reached.on_next(Bool(arrived)) + def set_replanning_enabled(self, enabled: bool) -> None: + with self._lock: + self._replanning_enabled = enabled + def get_state(self) -> NavigationState: return self._local_planner.get_state() @@ -267,6 +289,10 @@ def _replan_path(self) -> None: self.cancel_goal(arrived=True) return + if not self._replanning_enabled: + self.cancel_goal() + return + if not self._replan_limiter.can_retry(current_odom.position): self.cancel_goal() return @@ -291,6 +317,10 @@ def _plan_path(self) -> None: safe_goal = self._find_safe_goal(current_goal.position) if not safe_goal: + logger.warning( + "No safe goal found.", x=round(current_goal.x, 3), y=round(current_goal.y, 3) + ) + self.cancel_goal() return path = self._find_wide_path(safe_goal, current_odom.position) @@ -299,6 +329,7 @@ def _plan_path(self) -> None: logger.warning( "No path found to the goal.", x=round(safe_goal.x, 3), y=round(safe_goal.y, 3) ) + self.cancel_goal() return resampled_path = smooth_resample_path(path, current_goal, 0.1) @@ -312,7 +343,9 @@ def _find_wide_path(self, goal: Vector3, robot_pos: Vector3) -> Path | None: sizes_to_try: list[float] = [1.1] for size in sizes_to_try: - costmap = self._navigation_map.make_gradient_costmap(size) + distance = robot_pos.distance(goal) + navigation_map = self._navigation_map if distance > 1.5 else self._navigation_map_near + costmap = navigation_map.make_gradient_costmap(size) path = min_cost_astar(costmap, goal, robot_pos) if path and path.poses: logger.info(f"Found path {size}x robot width.") @@ -331,7 +364,7 @@ def _find_safe_goal(self, goal: Vector3) -> Vector3 | None: goal, algorithm="bfs_contiguous", cost_threshold=CostValues.OCCUPIED, - min_clearance=self._global_config.robot_rotation_diameter / 2, + min_clearance=self._safe_goal_clearance, max_search_distance=self._safe_goal_tolerance, ) @@ -346,3 +379,7 @@ def _find_safe_goal(self, goal: Vector3) -> Vector3 | None: logger.info("Found safe goal.", x=round(safe_goal.x, 2), y=round(safe_goal.y, 2)) return safe_goal + + def _reset_safe_goal_clearance(self) -> None: + with self._lock: + self._safe_goal_clearance = self._global_config.robot_rotation_diameter / 2 diff --git a/dimos/navigation/replanning_a_star/local_planner.py b/dimos/navigation/replanning_a_star/local_planner.py index d50d0def84..fd408692db 100644 --- a/dimos/navigation/replanning_a_star/local_planner.py +++ b/dimos/navigation/replanning_a_star/local_planner.py @@ -86,9 +86,13 @@ def __init__( self._navigation_map = navigation_map self._goal_tolerance = goal_tolerance + speed = self._speed + if global_config.nerf_speed < 1.0: + speed *= global_config.nerf_speed + self._controller = PController( self._global_config, - self._speed, + speed, self._control_frequency, ) diff --git a/dimos/navigation/replanning_a_star/min_cost_astar.py b/dimos/navigation/replanning_a_star/min_cost_astar.py index 55f502680c..2855a1ecdf 100644 --- a/dimos/navigation/replanning_a_star/min_cost_astar.py +++ b/dimos/navigation/replanning_a_star/min_cost_astar.py @@ -198,8 +198,9 @@ def min_cost_astar( continue if neighbor_val == CostValues.UNKNOWN: - # Unknown cells have a moderate traversal cost cell_cost = cost_threshold * unknown_penalty + if cell_cost >= cost_threshold: + continue elif neighbor_val == CostValues.FREE: cell_cost = 0.0 else: diff --git a/dimos/navigation/replanning_a_star/min_cost_astar_cpp.cpp b/dimos/navigation/replanning_a_star/min_cost_astar_cpp.cpp index f19b3bf826..5b4a575197 100644 --- a/dimos/navigation/replanning_a_star/min_cost_astar_cpp.cpp +++ b/dimos/navigation/replanning_a_star/min_cost_astar_cpp.cpp @@ -201,8 +201,10 @@ std::vector> min_cost_astar_cpp( double cell_cost; if (val == COST_UNKNOWN) { - // Unknown cells have a moderate traversal cost cell_cost = cost_threshold * unknown_penalty; + if (cell_cost >= cost_threshold) { + continue; + } } else if (val == COST_FREE) { cell_cost = 0.0; } else { diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index 796390f06c..842a6319d4 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -108,6 +108,18 @@ def cancel_goal(self) -> bool: self._planner.cancel_goal() return True + @rpc + def set_replanning_enabled(self, enabled: bool) -> None: + self._planner.set_replanning_enabled(enabled) + + @rpc + def set_safe_goal_clearance(self, clearance: float) -> None: + self._planner.set_safe_goal_clearance(clearance) + + @rpc + def reset_safe_goal_clearance(self) -> None: + self._planner.reset_safe_goal_clearance() + replanning_a_star_planner = ReplanningAStarPlanner.blueprint diff --git a/dimos/navigation/replanning_a_star/module_spec.py b/dimos/navigation/replanning_a_star/module_spec.py new file mode 100644 index 0000000000..c9ec73dc47 --- /dev/null +++ b/dimos/navigation/replanning_a_star/module_spec.py @@ -0,0 +1,29 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.navigation.base import NavigationState +from dimos.spec.utils import Spec + + +class ReplanningAStarPlannerSpec(Spec, Protocol): + def set_goal(self, goal: PoseStamped) -> bool: ... + def get_state(self) -> NavigationState: ... + def is_goal_reached(self) -> bool: ... + def cancel_goal(self) -> bool: ... + def set_replanning_enabled(self, enabled: bool) -> None: ... + def set_safe_goal_clearance(self, clearance: float) -> None: ... + def reset_safe_goal_clearance(self) -> None: ... diff --git a/dimos/navigation/replanning_a_star/navigation_map.py b/dimos/navigation/replanning_a_star/navigation_map.py index f1c149ded6..fde75c0b0e 100644 --- a/dimos/navigation/replanning_a_star/navigation_map.py +++ b/dimos/navigation/replanning_a_star/navigation_map.py @@ -15,17 +15,20 @@ from threading import RLock from dimos.core.global_config import GlobalConfig +from dimos.mapping.occupancy.gradient import GradientStrategy from dimos.mapping.occupancy.path_map import make_navigation_map from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid class NavigationMap: _global_config: GlobalConfig + _gradient_strategy: GradientStrategy _binary: OccupancyGrid | None = None _lock: RLock - def __init__(self, global_config: GlobalConfig) -> None: + def __init__(self, global_config: GlobalConfig, gradient_strategy: GradientStrategy) -> None: self._global_config = global_config + self._gradient_strategy = gradient_strategy self._lock = RLock() def update(self, occupancy_grid: OccupancyGrid) -> None: @@ -62,5 +65,6 @@ def make_gradient_costmap(self, robot_increase: float = 1.0) -> OccupancyGrid: return make_navigation_map( binary, self._global_config.robot_width * robot_increase, - strategy=self._global_config.planner_strategy, + strategy="simple", + gradient_strategy=self._gradient_strategy, ) diff --git a/dimos/navigation/replanning_a_star/position_tracker.py b/dimos/navigation/replanning_a_star/position_tracker.py index 77b4df0dd0..7d8249b562 100644 --- a/dimos/navigation/replanning_a_star/position_tracker.py +++ b/dimos/navigation/replanning_a_star/position_tracker.py @@ -34,10 +34,10 @@ class PositionTracker: _index: int _size: int - def __init__(self, time_window: float) -> None: + def __init__(self, time_window: float, threshold: float) -> None: self._lock = RLock() self._time_window = time_window - self._threshold = 0.4 + self._threshold = threshold self._max_points = int(_max_points_per_second * self._time_window) self.reset_data() diff --git a/dimos/navigation/replanning_a_star/test_min_cost_astar.py b/dimos/navigation/replanning_a_star/test_min_cost_astar.py index 9cc0cad29a..1ea28d1cae 100644 --- a/dimos/navigation/replanning_a_star/test_min_cost_astar.py +++ b/dimos/navigation/replanning_a_star/test_min_cost_astar.py @@ -59,6 +59,54 @@ def test_astar_corner(costmap_three_paths) -> None: np.testing.assert_array_equal(actual.data, expected.data) +def test_astar_unknown_penalty_blocks_unknown_cells(costmap) -> None: + """With unknown_penalty=1.0, unknown cells should be untraversable.""" + # Create a grid with a corridor of free cells and unknown cells surrounding it. + # Place start and goal such that the shortest path would go through unknown cells + # but with penalty=1.0 it should either avoid them or return None. + grid = np.full((100, 100), -1, dtype=np.int8) # All unknown + # Carve a U-shaped free corridor: left column, bottom row, right column + grid[10:90, 10] = 0 # left column + grid[89, 10:90] = 0 # bottom row + grid[10:90, 89] = 0 # right column + og = OccupancyGrid(grid, resolution=0.1) + + start = og.grid_to_world((10, 10)) + goal = og.grid_to_world((89, 10)) + + for use_cpp in [False, True]: + path = min_cost_astar(og, goal, start, unknown_penalty=1.0, use_cpp=use_cpp) + if path is None: + # No path through unknown is also acceptable + continue + # Verify no path cell lands on an unknown cell + for pose in path.poses: + gp = og.world_to_grid((pose.position.x, pose.position.y)) + gx, gy = round(gp.x), round(gp.y) + if 0 <= gx < 100 and 0 <= gy < 100: + assert grid[gy, gx] != -1, ( + f"Path traverses unknown cell at grid ({gx}, {gy}), use_cpp={use_cpp}" + ) + + +def test_astar_unknown_penalty_allows_with_low_penalty(costmap) -> None: + """With unknown_penalty < 1.0, unknown cells should be traversable.""" + grid = np.full((50, 50), -1, dtype=np.int8) # All unknown + grid[5, 5] = 0 # start cell free + grid[45, 45] = 0 # goal cell free + og = OccupancyGrid(grid, resolution=0.1) + + start = og.grid_to_world((5, 5)) + goal = og.grid_to_world((45, 45)) + + for use_cpp in [False, True]: + path = min_cost_astar(og, goal, start, unknown_penalty=0.5, use_cpp=use_cpp) + assert path is not None, ( + f"Should find path through unknown with penalty=0.5, use_cpp={use_cpp}" + ) + assert len(path.poses) > 0 + + def test_astar_python_and_cpp(costmap) -> None: start = Vector3(4.0, 2.0, 0) goal = Vector3(6.15, 10.0) diff --git a/dimos/perception/detection/type/detection2d/test_bbox.py b/dimos/perception/detection/type/detection2d/test_bbox.py index 5a76b41601..66795d7782 100644 --- a/dimos/perception/detection/type/detection2d/test_bbox.py +++ b/dimos/perception/detection/type/detection2d/test_bbox.py @@ -14,6 +14,7 @@ import pytest +@pytest.mark.skipif_in_ci def test_detection2d(detection2d) -> None: # def test_detection_basic_properties(detection2d): """Test basic detection properties.""" diff --git a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py index 81df107ecf..fc1895373c 100644 --- a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py +++ b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py @@ -20,7 +20,7 @@ import os import threading import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock, create_autospec, patch from dotenv import load_dotenv @@ -42,7 +42,6 @@ ) from dimos.perception.experimental.temporal_memory.temporal_memory import ( TemporalMemory, - TemporalMemoryConfig, ) from dimos.perception.experimental.temporal_memory.temporal_state import TemporalState from dimos.perception.experimental.temporal_memory.temporal_utils.graph_utils import ( @@ -522,8 +521,8 @@ class VideoReplayModule(Module): video_out: Out[Image] - def __init__(self, num_frames: int = 5) -> None: - super().__init__() + def __init__(self, num_frames: int = 5, **kwargs: Any) -> None: + super().__init__(**kwargs) self.num_frames = num_frames @rpc @@ -596,14 +595,12 @@ def temporal_memory_module(self, dimos_cluster, tmp_path): tm = dimos_cluster.deploy( TemporalMemory, vlm=vlm, - config=TemporalMemoryConfig( - fps=1.0, - window_s=2.0, - stride_s=2.0, - summary_interval_s=10.0, - max_frames_per_window=3, - db_dir=str(db_dir), - ), + fps=1.0, + window_s=2.0, + stride_s=2.0, + summary_interval_s=10.0, + max_frames_per_window=3, + db_dir=str(db_dir), ) yield tm try: diff --git a/dimos/perception/perceive_loop_skill.py b/dimos/perception/perceive_loop_skill.py index 4532e61c2e..d18526147f 100644 --- a/dimos/perception/perceive_loop_skill.py +++ b/dimos/perception/perceive_loop_skill.py @@ -14,10 +14,13 @@ from __future__ import annotations +from datetime import datetime, timezone import json +import os from threading import RLock from typing import TYPE_CHECKING, Any +import cv2 from langchain_core.messages import HumanMessage from dimos.agents.agent import AgentSpec @@ -33,6 +36,9 @@ if TYPE_CHECKING: from reactivex.abc import DisposableBase + from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox + from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D + logger = setup_logger() @@ -41,12 +47,13 @@ class PerceiveLoopSkill(Module): color_image: In[Image] _agent_spec: AgentSpec - _period: float = 0.5 # seconds - how often to run the perceive loop + _period: float = 0.1 # seconds - how often to run the perceive loop def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self._vl_model = create(self.config.g.detection_model) self._active_lookout: tuple[str, ...] = () + self._then: dict[str, Any] | None = None self._lookout_subscription: DisposableBase | None = None self._model_started: bool = False self._lock = RLock() @@ -61,13 +68,40 @@ def stop(self) -> None: super().stop() @skill - def look_out_for(self, description_of_things: list[str]) -> str: + def look_out_for( + self, description_of_things: list[str], then: dict[str, Any] | None = None + ) -> str: """This tool will continuously look out for things matching the description in the input list, and notify the agent whenever it finds a match. After the match is found, it will stop. You can ask it for `look_out_for(["small dogs", "cats"])` and you will be notified back whenever such a detection is made. + + Optionally, you can specify a `then` parameter to automatically execute + another tool when a match is found, without waiting for the agent to + process the notification. This is useful for time-sensitive actions like + following a detected person. + + The `then` parameter is a dict with: + - "tool": name of the tool to call (e.g. "follow_person") + - "args": dict of arguments to pass to the tool + + In the args, you can use template variables that will be replaced with + detection data: + - "$bbox": the bounding box [x1, y1, x2, y2] of the best detection + - "$label": the label/name of the detection + - "$image": base64-encoded JPEG of the frame the detection was made on + + Example: + look_out_for(["person"], then={ + "tool": "follow_person", + "args": { + "query": "person", + "initial_bbox": "$bbox", + "initial_image": "$image", + } + }) """ with self._lock: @@ -83,6 +117,7 @@ def look_out_for(self, description_of_things: list[str]) -> str: self._vl_model.start() self._model_started = True self._active_lookout = tuple(description_of_things) + self._then = then self._lookout_subscription = sharpest.subscribe( on_next=self._on_image, on_error=lambda e: logger.exception("Error in perceive loop", exc_info=e), @@ -114,6 +149,9 @@ def _on_image(self, image: Image) -> None: if not detections: return + if os.environ.get("DEBUG"): + _write_debug_image(image, detections) + with self._lock: if not self._active_lookout: return @@ -121,12 +159,30 @@ def _on_image(self, image: Image) -> None: self._lookout_subscription.dispose() self._lookout_subscription = None self._active_lookout = () + then = self._then + self._then = None self._vl_model.stop() self._model_started = False - self._agent_spec.add_message( - HumanMessage(f"Found a match for {active_lookout_str}. Please announce audibly.") + if then is None: + self._agent_spec.add_message( + HumanMessage(f"Found a match for {active_lookout_str}. Please announce audibly.") + ) + return + + best = max(detections.detections, key=lambda d: d.bbox_2d_volume()) + continuation_context: dict[str, Any] = { + "bbox": list(best.bbox), + "label": best.name, + "image": image.to_base64(quality=70), + } + logger.info( + "Lookout matched, dispatching continuation", + lookout=active_lookout_str, + continuation=then, + detection=continuation_context, ) + self._agent_spec.dispatch_continuation(then, continuation_context) def _stop_lookout(self) -> None: with self._lock: @@ -134,6 +190,28 @@ def _stop_lookout(self) -> None: self._lookout_subscription.dispose() self._lookout_subscription = None self._active_lookout = () + self._then = None if self._model_started: self._vl_model.stop() self._model_started = False + + +def _write_debug_image(image: Image, detections: ImageDetections2D[Detection2DBBox]) -> None: + try: + debug_img = image.to_opencv().copy() + for det in detections.detections: + x1, y1, x2, y2 = (int(v) for v in det.bbox) + cv2.rectangle(debug_img, (x1, y1), (x2, y2), (0, 255, 0), 2) + cv2.putText( + debug_img, + det.name, + (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + ts = datetime.now(tz=timezone.utc).isoformat().replace(":", "-") + cv2.imwrite(f"debug-{ts}.ignore.jpg", debug_img) + except Exception: + pass # Ignore debug drawing errors diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py index d8567036bf..2827be6a32 100644 --- a/dimos/perception/test_spatial_memory_module.py +++ b/dimos/perception/test_spatial_memory_module.py @@ -15,6 +15,7 @@ import asyncio import os import time +from typing import Any import pytest from reactivex import operators as ops @@ -76,7 +77,7 @@ def stop(self) -> None: class OdometryReplayModule(Module): """Module that replays odometry data and publishes to the tf system.""" - def __init__(self, odom_path: str) -> None: + def __init__(self, odom_path: str, **kwargs: Any) -> None: super().__init__() self.odom_path = odom_path self._subscription = None @@ -134,11 +135,11 @@ async def test_spatial_memory_module_with_replay(dimos, tmp_path): # Deploy modules # Video replay module - video_module = dimos.deploy(VideoReplayModule, video_path) + video_module = dimos.deploy(VideoReplayModule, video_path=video_path) video_module.video_out.transport = LCMTransport("/test_video", Image) # Odometry replay module (publishes to tf system directly) - odom_module = dimos.deploy(OdometryReplayModule, odom_path) + odom_module = dimos.deploy(OdometryReplayModule, odom_path=odom_path) # Spatial memory module spatial_memory = dimos.deploy( diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py index 80e6ec701a..10dd290e2d 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py @@ -19,6 +19,7 @@ from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( wavefront_frontier_explorer, ) +from dimos.navigation.patrolling.module import PatrollingModule from dimos.navigation.replanning_a_star.module import replanning_a_star_planner from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic @@ -28,6 +29,7 @@ cost_mapper(), replanning_a_star_planner(), wavefront_frontier_explorer(), + PatrollingModule.blueprint(), ).global_config(n_workers=7, robot_model="unitree_go2") __all__ = ["unitree_go2"] diff --git a/dimos/simulation/mujoco/direct_cmd_vel_explorer.py b/dimos/simulation/mujoco/direct_cmd_vel_explorer.py new file mode 100644 index 0000000000..58dc91f6b1 --- /dev/null +++ b/dimos/simulation/mujoco/direct_cmd_vel_explorer.py @@ -0,0 +1,107 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import threading +from typing import TYPE_CHECKING + +from dimos.core.transport import LCMTransport +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +if TYPE_CHECKING: + from collections.abc import Callable + + +class DirectCmdVelExplorer: + def __init__( + self, + linear_speed: float = 0.8, + rotation_speed: float = 1.5, + publish_rate: float = 10.0, + ) -> None: + self.linear_speed = linear_speed + self.rotation_speed = rotation_speed + self._dt = 1.0 / publish_rate + self._cmd_vel: LCMTransport[Twist] | None = None + self._odom: LCMTransport[PoseStamped] | None = None + self._pose: PoseStamped | None = None + self._new_pose = threading.Event() + self._unsub: Callable[[], None] | None = None + + def start(self) -> None: + self._cmd_vel = LCMTransport("/cmd_vel", Twist) + self._odom = LCMTransport("/odom", PoseStamped) + self._pose = None + self._unsub = self._odom.subscribe(self._on_odom) # type: ignore[func-returns-value] + + def stop(self) -> None: + if self._unsub: + self._unsub() + if self._cmd_vel: + self._cmd_vel.stop() + if self._odom: + self._odom.stop() + + def _on_odom(self, msg: PoseStamped) -> None: + self._pose = msg + self._new_pose.set() + + def _wait_for_pose(self) -> PoseStamped: + self._new_pose.clear() + self._new_pose.wait(timeout=5.0) + assert self._pose is not None, "No odom received" + return self._pose + + @staticmethod + def _normalize_angle(angle: float) -> float: + while angle > math.pi: + angle -= 2 * math.pi + while angle < -math.pi: + angle += 2 * math.pi + return angle + + def _stop(self) -> None: + assert self._cmd_vel is not None + self._cmd_vel.broadcast(None, Twist(linear=Vector3(), angular=Vector3())) + + def _drive_to(self, target_x: float, target_y: float) -> None: + """Pursuit controller: steer toward the target while driving forward.""" + while True: + pose = self._wait_for_pose() + dx = target_x - pose.x + dy = target_y - pose.y + distance = math.hypot(dx, dy) + if distance < 0.3: + break + target_heading = math.atan2(dy, dx) + heading_error = self._normalize_angle(target_heading - pose.yaw) + # Only drive forward when roughly facing the target. + if abs(heading_error) > 0.3: + linear = 0.0 + else: + linear = self.linear_speed + angular = max(-self.rotation_speed, min(self.rotation_speed, heading_error * 2.0)) + assert self._cmd_vel is not None + self._cmd_vel.broadcast( + None, + Twist(linear=Vector3(linear, 0, 0), angular=Vector3(0, 0, angular)), + ) + self._stop() + + def follow_points(self, waypoints: list[tuple[float, float]]) -> None: + self._wait_for_pose() + for tx, ty in waypoints: + self._drive_to(tx, ty) diff --git a/docs/capabilities/navigation/native/assets/coverage.png b/docs/capabilities/navigation/native/assets/coverage.png new file mode 100644 index 0000000000..2ad2112071 --- /dev/null +++ b/docs/capabilities/navigation/native/assets/coverage.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c5ef9943e14c2d02fa2e19032ffeb2fc79f927c903e552e7c0db01b858f5297 +size 256502 diff --git a/docs/capabilities/navigation/native/assets/frontier.png b/docs/capabilities/navigation/native/assets/frontier.png new file mode 100644 index 0000000000..97089338f5 --- /dev/null +++ b/docs/capabilities/navigation/native/assets/frontier.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f2e35b3a6cc1e82667958f6bb3120a5f0bb5bba99f156df7283b774559168b5 +size 251903 diff --git a/docs/capabilities/navigation/native/assets/patrol_path.png b/docs/capabilities/navigation/native/assets/patrol_path.png new file mode 100644 index 0000000000..4d53c29409 --- /dev/null +++ b/docs/capabilities/navigation/native/assets/patrol_path.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0cecf773affedca3d14d781e956d20ec9b396df53b5473e41fb7a182d700bef2 +size 476239 diff --git a/docs/capabilities/navigation/native/assets/random.png b/docs/capabilities/navigation/native/assets/random.png new file mode 100644 index 0000000000..b407034eb6 --- /dev/null +++ b/docs/capabilities/navigation/native/assets/random.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18ab48a549d02d1cd63c8c21b4294acc9c235e8c8f704e2c6ee71d0399ca4aa0 +size 260526 diff --git a/docs/capabilities/navigation/native/index.md b/docs/capabilities/navigation/native/index.md index 6a8c5224e9..1c9ba4a0da 100644 --- a/docs/capabilities/navigation/native/index.md +++ b/docs/capabilities/navigation/native/index.md @@ -116,6 +116,42 @@ All visualization layers shown together ![All layers](assets/5-all.png) +## Patrolling + +The patrolling system drives the robot to systematically cover a **known** area. It is exposed as an agent skill. An LLM agent can call `start_patrol` and `stop_patrol` to control it. Note that the area has to be explored first. + +### How it works + +1. **Visitation tracking** — As the robot moves, a visitation grid (aligned to the costmap) marks cells around the robot's position as visited. This gives the system a running picture of where the robot has and hasn't been. This expires over time, and has to be visited again. + +2. **Goal selection** — A *patrol router* picks the next goal. The default strategy is **coverage**: it samples a handful of candidate points from unvisited, obstacle-free cells, plans a path to each one, and picks the candidate whose path would cover the most new ground. Candidates are weighted by a Voronoi skeleton so goals are more likely to be spread evenly across the map, rather than clustering in large open areas. + +3. **Navigation loop** — The module sends each goal to the planner and waits for a `goal_reached` signal before requesting the next one. If no valid goal is available (e.g. the map hasn't loaded yet), it retries after a short delay. + +4. **Stopping** — When patrol is stopped, the module cancels in-progress navigation by publishing the robot's current pose as the goal, then re-enables the planner's normal replanning behavior. + +### Patrol router strategies + +| Router | Behavior | +|--------------|------------------------------------------------------------------------------------------------| +| `coverage` | Maximizes new-cell coverage per goal. Uses Voronoi weighting for even spatial distribution. | +| `random` | Picks a random unvisited, obstacle-free cell. | +| `frontier` | Targets the boundary between known and unknown space, useful for exploration-style patrol. | + +### Safety + +Goal candidates are filtered through a **safe mask** — the free-space region eroded by the robot's clearance radius — so the robot is never sent to a position too close to walls or obstacles. The planner's safe-goal clearance is also tightened while patrolling to ensure the robot can rotate in place at every goal. + +### Router comparison + +| Coverage | Frontier | Random | +|----------|----------|--------| +| ![coverage](assets/coverage.png) | ![frontier](assets/frontier.png) | ![random](assets/random.png) | + +### Sample patrol trace (26 min) + +![Patrol path](assets/patrol_path.png) + ## Blueprint Composition The navigation stack is composed in the [`unitree_go2`](/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py) blueprint: diff --git a/misc/optimize_patrol/optimize_candidates.py b/misc/optimize_patrol/optimize_candidates.py new file mode 100644 index 0000000000..3732b5baef --- /dev/null +++ b/misc/optimize_patrol/optimize_candidates.py @@ -0,0 +1,126 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Find the optimal _candidates_to_consider value for CoveragePatrolRouter. + +For each candidate count, runs until TARGET_COVERAGE is reached and measures: + - Average next_goal() call duration (the planning cost) + - Distance traveled to reach the target coverage (path quality) + +Produces a single dual-axis chart with both metrics. +""" + +from concurrent.futures import ProcessPoolExecutor, as_completed +import json +import subprocess +import sys + +import matplotlib.pyplot as plt +import numpy as np + +CANDIDATE_VALUES = list(range(1, 16)) +N_ITERATIONS = 9 +TARGET_COVERAGE = 0.25 +MAX_WORKERS = 32 + + +def run_child(candidates: int) -> tuple[int, float, float]: + result = subprocess.run( + [ + sys.executable, + "-m", + "misc.optimize_patrol.optimize_candidates_child", + "--candidates", + str(candidates), + "--target_coverage", + str(TARGET_COVERAGE), + "--n_iterations", + str(N_ITERATIONS), + ], + capture_output=True, + text=True, + ) + if result.returncode != 0: + print(f"FAILED candidates={candidates}", file=sys.stderr) + print(result.stderr, file=sys.stderr) + return (candidates, float("nan"), float("nan")) + data = json.loads(result.stdout.strip()) + return (candidates, data["avg_next_goal_time"], data["distance"]) + + +def main() -> None: + print(f"Sweeping candidates_to_consider in {CANDIDATE_VALUES}") + print( + f" {N_ITERATIONS} iterations each, target coverage={TARGET_COVERAGE}, up to {MAX_WORKERS} workers" + ) + + results: dict[int, tuple[float, float]] = {} + + with ProcessPoolExecutor(max_workers=MAX_WORKERS) as pool: + futures = {pool.submit(run_child, c): c for c in CANDIDATE_VALUES} + for i, future in enumerate(as_completed(futures), 1): + cand, avg_time, distance = future.result() + results[cand] = (avg_time, distance) + print( + f"[{i}/{len(CANDIDATE_VALUES)}] candidates={cand}" + f" -> avg_next_goal={avg_time * 1000:.1f}ms distance={distance:.0f}m" + ) + + xs = sorted(results.keys()) + avg_times_ms = np.array([results[x][0] * 1000 for x in xs]) # Convert to ms. + distances = np.array([results[x][1] for x in xs]) + + fig, ax1 = plt.subplots(figsize=(9, 5)) + color_time = "#FF5722" + color_dist = "#2196F3" + + ax1.set_xlabel("candidates_to_consider") + ax1.set_ylabel("Avg next_goal() duration (ms)", color=color_time) + ax1.plot( + xs, + avg_times_ms, + "s-", + color=color_time, + linewidth=2, + markersize=6, + label="Avg planning time", + ) + ax1.tick_params(axis="y", labelcolor=color_time) + ax1.set_xticks(xs) + + ax2 = ax1.twinx() + ax2.set_ylabel("Distance to reach target (m)", color=color_dist) + ax2.plot(xs, distances, "o-", color=color_dist, linewidth=2, markersize=6, label="Distance") + ax2.tick_params(axis="y", labelcolor=color_dist) + + fig.suptitle( + f"Planning cost vs path quality to reach {TARGET_COVERAGE:.0%} coverage" + f" (median of {N_ITERATIONS} iters)" + ) + fig.tight_layout() + out = "candidates_optimization.png" + fig.savefig(out, dpi=150, bbox_inches="tight") + print(f"\nChart saved to {out}") + plt.close(fig) + + # Summary table. + print("\n--- Summary ---") + print(f"{'candidates':>12} {'avg_time(ms)':>14} {'distance(m)':>14}") + for x in xs: + avg_t, dist = results[x] + print(f"{x:>12} {avg_t * 1000:>14.1f} {dist:>14.0f}") + + +if __name__ == "__main__": + main() diff --git a/misc/optimize_patrol/optimize_candidates_child.py b/misc/optimize_patrol/optimize_candidates_child.py new file mode 100644 index 0000000000..f62e5d2d22 --- /dev/null +++ b/misc/optimize_patrol/optimize_candidates_child.py @@ -0,0 +1,173 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Child process: test a single _candidates_to_consider value. + +Runs until target coverage is reached. Outputs JSON: + {"avg_next_goal_time": float, "distance": float} +""" + +import argparse +import json +import math +import time + +import numpy as np + +from dimos.mapping.occupancy.gradient import gradient +from dimos.mapping.occupancy.path_resampling import smooth_resample_path +from dimos.mapping.pointclouds.occupancy import height_cost_occupancy +from dimos.mapping.pointclouds.util import read_pointcloud +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.patrolling.create_patrol_router import create_patrol_router +from dimos.navigation.patrolling.routers.coverage_patrol_router import CoveragePatrolRouter +from dimos.navigation.patrolling.utilities import point_to_pose_stamped +from dimos.navigation.replanning_a_star.min_cost_astar import min_cost_astar +from dimos.utils.data import get_data + +SCORING_STAMP_RADIUS_M = 0.2 +CLEARANCE_RADIUS_M = 0.2 +MAX_DISTANCE = 50_000.0 # Safety cap to avoid infinite loops. + + +def _circular_disk(radius_cells: int) -> np.ndarray: + y, x = np.ogrid[-radius_cells : radius_cells + 1, -radius_cells : radius_cells + 1] + return (x * x + y * y) <= radius_cells * radius_cells + + +def _stamp_scoring_map( + visited: np.ndarray, x: float, y: float, occupancy_grid, radius_cells: int, disk: np.ndarray +) -> None: + grid_pos = occupancy_grid.world_to_grid((x, y)) + col, row = int(grid_pos.x), int(grid_pos.y) + h, w = visited.shape + r = radius_cells + if row + r < 0 or row - r >= h or col + r < 0 or col - r >= w: + return + r_min = max(0, row - r) + r_max = min(h, row + r + 1) + c_min = max(0, col - r) + c_max = min(w, col + r + 1) + d_r_min = r_min - (row - r) + d_r_max = d_r_min + (r_max - r_min) + d_c_min = c_min - (col - r) + d_c_max = d_c_min + (c_max - c_min) + visited[r_min:r_max, c_min:c_max] |= disk[d_r_min:d_r_max, d_c_min:d_c_max] + + +def run_iteration( + candidates_to_consider: int, + target_coverage: float, + occupancy_grid, + costmap, + scoring_radius_cells: int, + scoring_disk: np.ndarray, +) -> tuple[float, float]: + """Returns (avg_next_goal_time_seconds, distance_traveled).""" + start = (-1.03, -13.48) + + router = create_patrol_router("coverage", CLEARANCE_RADIUS_M) + assert isinstance(router, CoveragePatrolRouter) + router._candidates_to_consider = candidates_to_consider + router.handle_occupancy_grid(occupancy_grid) + router.handle_odom(point_to_pose_stamped(start)) + + h, w = occupancy_grid.height, occupancy_grid.width + scoring_visited = np.zeros((h, w), dtype=bool) + free_mask = occupancy_grid.grid == 0 + total_free = int(np.count_nonzero(free_mask)) + if total_free == 0: + return 0.0, 0.0 + + _stamp_scoring_map( + scoring_visited, start[0], start[1], occupancy_grid, scoring_radius_cells, scoring_disk + ) + + distance_walked = 0.0 + next_goal_times: list[float] = [] + + while distance_walked < MAX_DISTANCE: + t0 = time.perf_counter() + goal = router.next_goal() + next_goal_times.append(time.perf_counter() - t0) + + if goal is None: + break + path = min_cost_astar(costmap, goal.position, start, unknown_penalty=1.0, use_cpp=True) + if path is None: + continue + path = smooth_resample_path(path, goal, 0.1) + + for pose in path.poses: + dx = pose.position.x - start[0] + dy = pose.position.y - start[1] + distance_walked += math.sqrt(dx * dx + dy * dy) + start = (pose.position.x, pose.position.y) + + router.handle_odom(pose) + _stamp_scoring_map( + scoring_visited, + pose.position.x, + pose.position.y, + occupancy_grid, + scoring_radius_cells, + scoring_disk, + ) + + coverage = int(np.count_nonzero(scoring_visited & free_mask)) / total_free + if coverage >= target_coverage: + break + + avg_time = float(np.mean(next_goal_times)) if next_goal_times else 0.0 + return avg_time, distance_walked + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--candidates", type=int, required=True) + parser.add_argument("--target_coverage", type=float, default=0.3) + parser.add_argument("--n_iterations", type=int, default=3) + args = parser.parse_args() + + data = read_pointcloud(get_data("big_office.ply")) + cloud = PointCloud2.from_numpy(np.asarray(data.points), frame_id="") + occupancy_grid = height_cost_occupancy(cloud) + costmap = gradient(occupancy_grid, max_distance=1.5) + + scoring_radius_cells = int(np.ceil(SCORING_STAMP_RADIUS_M / occupancy_grid.resolution)) + scoring_disk = _circular_disk(scoring_radius_cells) + + avg_times = [] + distances = [] + for _ in range(args.n_iterations): + avg_t, dist = run_iteration( + args.candidates, + args.target_coverage, + occupancy_grid, + costmap, + scoring_radius_cells, + scoring_disk, + ) + avg_times.append(avg_t) + distances.append(dist) + + result = { + "avg_next_goal_time": float(np.median(avg_times)), + "distance": float(np.median(distances)), + } + print(json.dumps(result)) + + +if __name__ == "__main__": + main() diff --git a/misc/optimize_patrol/optimize_patrol_router.py b/misc/optimize_patrol/optimize_patrol_router.py new file mode 100644 index 0000000000..1ebec949b8 --- /dev/null +++ b/misc/optimize_patrol/optimize_patrol_router.py @@ -0,0 +1,117 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parent process: matrix-test saturation_threshold and clearance_radius_m.""" + +from concurrent.futures import ProcessPoolExecutor, as_completed +import itertools +import subprocess +import sys + +import matplotlib.pyplot as plt +import numpy as np + +N_POINTS_SAT = 9 +N_POINTS_CLR = 10 +N_ITERATIONS = 5 +TOTAL_DISTANCE = 4000.0 +MAX_WORKERS = 32 + +SAT_MIN, SAT_MAX = 0.1, 0.9 +CLR_MIN, CLR_MAX = 0.1, 1.0 + + +def run_child(saturation_threshold: float, clearance_radius_m: float) -> tuple[float, float, float]: + result = subprocess.run( + [ + sys.executable, + "-m", + "misc.optimize_patrol.optimize_patrol_router_child", + "--saturation_threshold", + str(saturation_threshold), + "--clearance_radius_m", + str(clearance_radius_m), + "--n_iterations", + str(N_ITERATIONS), + "--total_distance", + str(TOTAL_DISTANCE), + ], + capture_output=True, + text=True, + ) + if result.returncode != 0: + print(f"FAILED sat={saturation_threshold} clr={clearance_radius_m}", file=sys.stderr) + print(result.stderr, file=sys.stderr) + return (saturation_threshold, clearance_radius_m, float("nan")) + score = float(result.stdout.strip()) + return (saturation_threshold, clearance_radius_m, score) + + +def main() -> None: + sat_values = np.linspace(SAT_MIN, SAT_MAX, N_POINTS_SAT) + clr_values = np.linspace(CLR_MIN, CLR_MAX, N_POINTS_CLR) + combos = list(itertools.product(sat_values, clr_values)) + + print(f"Running {len(combos)} combinations with up to {MAX_WORKERS} workers...") + + results: dict[tuple[float, float], float] = {} + + with ProcessPoolExecutor(max_workers=MAX_WORKERS) as pool: + futures = {pool.submit(run_child, sat, clr): (sat, clr) for sat, clr in combos} + for i, future in enumerate(as_completed(futures), 1): + sat, clr, score = future.result() + results[(sat, clr)] = score + print(f"[{i}/{len(combos)}] sat={sat:.3f} clr={clr:.3f} -> score={score:.4f}") + + # Build matrix for plotting. + matrix = np.zeros((N_POINTS_SAT, N_POINTS_CLR)) + for i, sat in enumerate(sat_values): + for j, clr in enumerate(clr_values): + matrix[i, j] = results.get((sat, clr), float("nan")) + + fig, ax = plt.subplots(figsize=(8, 6)) + im = ax.imshow(matrix, origin="lower", aspect="auto", cmap="viridis") + ax.set_xticks(range(N_POINTS_CLR)) + ax.set_xticklabels([f"{v:.2f}" for v in clr_values]) + ax.set_yticks(range(N_POINTS_SAT)) + ax.set_yticklabels([f"{v:.2f}" for v in sat_values]) + ax.set_xlabel("clearance_radius_m") + ax.set_ylabel("saturation_threshold") + ax.set_title(f"Coverage score (median of {N_ITERATIONS} iters, {TOTAL_DISTANCE}m walk)") + cbar = fig.colorbar(im, ax=ax) + cbar.set_label("Coverage (fraction of free cells visited)") + + # Annotate cells with values. + for i in range(N_POINTS_SAT): + for j in range(N_POINTS_CLR): + val = matrix[i, j] + if not np.isnan(val): + ax.text( + j, + i, + f"{val:.3f}", + ha="center", + va="center", + fontsize=8, + color="white" if val < (np.nanmax(matrix) + np.nanmin(matrix)) / 2 else "black", + ) + + out_path = "patrol_router_optimization.png" + fig.savefig(out_path, dpi=150, bbox_inches="tight") + print(f"Chart saved to {out_path}") + plt.close(fig) + + +if __name__ == "__main__": + main() diff --git a/misc/optimize_patrol/optimize_patrol_router_child.py b/misc/optimize_patrol/optimize_patrol_router_child.py new file mode 100644 index 0000000000..1f037818dd --- /dev/null +++ b/misc/optimize_patrol/optimize_patrol_router_child.py @@ -0,0 +1,155 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Child process: test a single (saturation_threshold, clearance_radius_m) pair.""" + +import argparse +import math + +import numpy as np + +from dimos.mapping.occupancy.gradient import gradient +from dimos.mapping.occupancy.path_resampling import smooth_resample_path +from dimos.mapping.pointclouds.occupancy import height_cost_occupancy +from dimos.mapping.pointclouds.util import read_pointcloud +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.navigation.patrolling.create_patrol_router import create_patrol_router +from dimos.navigation.patrolling.routers.visitation_history import VisitationHistory +from dimos.navigation.patrolling.utilities import point_to_pose_stamped +from dimos.navigation.replanning_a_star.min_cost_astar import min_cost_astar +from dimos.utils.data import get_data + +SCORING_STAMP_RADIUS_M = 0.2 + + +def _circular_disk(radius_cells: int) -> np.ndarray: + y, x = np.ogrid[-radius_cells : radius_cells + 1, -radius_cells : radius_cells + 1] + return (x * x + y * y) <= radius_cells * radius_cells + + +def _stamp_scoring_map( + visited: np.ndarray, x: float, y: float, occupancy_grid, radius_cells: int, disk: np.ndarray +) -> None: + grid_pos = occupancy_grid.world_to_grid((x, y)) + col, row = int(grid_pos.x), int(grid_pos.y) + h, w = visited.shape + r = radius_cells + if row + r < 0 or row - r >= h or col + r < 0 or col - r >= w: + return + r_min = max(0, row - r) + r_max = min(h, row + r + 1) + c_min = max(0, col - r) + c_max = min(w, col + r + 1) + d_r_min = r_min - (row - r) + d_r_max = d_r_min + (r_max - r_min) + d_c_min = c_min - (col - r) + d_c_max = d_c_min + (c_max - c_min) + visited[r_min:r_max, c_min:c_max] |= disk[d_r_min:d_r_max, d_c_min:d_c_max] + + +def run_iteration( + saturation_threshold: float, + clearance_radius_m: float, + total_distance: float, + occupancy_grid, + costmap, + scoring_radius_cells: int, + scoring_disk: np.ndarray, +) -> float: + start = (-1.03, -13.48) + + VisitationHistory._saturation_threshold = saturation_threshold + router = create_patrol_router("coverage", clearance_radius_m) + router.handle_occupancy_grid(occupancy_grid) + router.handle_odom(point_to_pose_stamped(start)) + + h, w = occupancy_grid.height, occupancy_grid.width + scoring_visited = np.zeros((h, w), dtype=bool) + free_mask = occupancy_grid.grid == 0 + + _stamp_scoring_map( + scoring_visited, start[0], start[1], occupancy_grid, scoring_radius_cells, scoring_disk + ) + + distance_walked = 0.0 + + while distance_walked < total_distance: + goal = router.next_goal() + if goal is None: + break + path = min_cost_astar(costmap, goal.position, start, unknown_penalty=1.0, use_cpp=True) + if path is None: + continue + path = smooth_resample_path(path, goal, 0.1) + + for pose in path.poses: + dx = pose.position.x - start[0] + dy = pose.position.y - start[1] + distance_walked += math.sqrt(dx * dx + dy * dy) + start = (pose.position.x, pose.position.y) + + router.handle_odom(pose) + _stamp_scoring_map( + scoring_visited, + pose.position.x, + pose.position.y, + occupancy_grid, + scoring_radius_cells, + scoring_disk, + ) + + if distance_walked >= total_distance: + break + + total_free = int(np.count_nonzero(free_mask)) + if total_free == 0: + return 0.0 + return int(np.count_nonzero(scoring_visited & free_mask)) / total_free + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--saturation_threshold", type=float, required=True) + parser.add_argument("--clearance_radius_m", type=float, required=True) + parser.add_argument("--n_iterations", type=int, default=5) + parser.add_argument("--total_distance", type=float, default=100.0) + args = parser.parse_args() + + data = read_pointcloud(get_data("big_office.ply")) + cloud = PointCloud2.from_numpy(np.asarray(data.points), frame_id="") + occupancy_grid = height_cost_occupancy(cloud) + costmap = gradient(occupancy_grid, max_distance=1.5) + + scoring_radius_cells = int(np.ceil(SCORING_STAMP_RADIUS_M / occupancy_grid.resolution)) + scoring_disk = _circular_disk(scoring_radius_cells) + + scores = [] + for _ in range(args.n_iterations): + score = run_iteration( + args.saturation_threshold, + args.clearance_radius_m, + args.total_distance, + occupancy_grid, + costmap, + scoring_radius_cells, + scoring_disk, + ) + scores.append(score) + + median = float(np.median(scores)) + print(median) + + +if __name__ == "__main__": + main() diff --git a/misc/optimize_patrol/plot_path.py b/misc/optimize_patrol/plot_path.py new file mode 100644 index 0000000000..2c473de747 --- /dev/null +++ b/misc/optimize_patrol/plot_path.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Record robot position from /odom and plot the travel path on Ctrl+C.""" + +from __future__ import annotations + +import argparse +import math +import signal +import time + +import matplotlib.pyplot as plt +import numpy as np + +from dimos.core.transport import LCMTransport +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + +MIN_DIST = 0.05 # minimum distance (m) between recorded points +STUCK_RADIUS = 0.6 # if robot stays within this radius (m) ... +STUCK_TIMEOUT = 60.0 # ... for this many seconds, stop recording + + +def main() -> None: + parser = argparse.ArgumentParser(description="Record /odom and plot patrol path") + parser.add_argument("--output", "-o", default="patrol_path.ignore.png") + args = parser.parse_args() + + transport: LCMTransport[PoseStamped] = LCMTransport("/odom", PoseStamped) + + xs: list[float] = [] + ys: list[float] = [] + t_start: list[float] = [] # single-element list so closure can mutate + stuck_anchor: list[float] = [0.0, 0.0] # (x, y) center for stuck detection + stuck_since: list[float] = [0.0] # timestamp when robot entered current stuck zone + stop = False + + def on_msg(msg: PoseStamped) -> None: + nonlocal stop + x, y = msg.position.x, msg.position.y + + # Record start time on first message. + if not t_start: + t_start.append(time.time()) + stuck_anchor[0], stuck_anchor[1] = x, y + stuck_since[0] = time.time() + + # Only record if far enough from the last recorded point. + if xs: + dx = x - xs[-1] + dy = y - ys[-1] + if math.hypot(dx, dy) < MIN_DIST: + return + + xs.append(x) + ys.append(y) + + # Stuck detection: check if robot left the stuck circle. + dist_from_anchor = math.hypot(x - stuck_anchor[0], y - stuck_anchor[1]) + if dist_from_anchor > STUCK_RADIUS: + # Robot moved out — reset anchor to current position. + stuck_anchor[0], stuck_anchor[1] = x, y + stuck_since[0] = time.time() + elif time.time() - stuck_since[0] > STUCK_TIMEOUT: + print( + f"\nRobot stuck within {STUCK_RADIUS}m radius for >{STUCK_TIMEOUT:.0f}s — stopping." + ) + stop = True + + transport.start() + transport.subscribe(on_msg) + + print("Listening on /odom ... recording positions. Press Ctrl+C to stop and plot.") + + def _handle_sigint(_sig: int, _frame: object) -> None: + nonlocal stop + stop = True + + signal.signal(signal.SIGINT, _handle_sigint) + + while not stop: + time.sleep(0.05) + + transport.stop() + t_end = time.time() + + # Compute stats. + elapsed = t_end - t_start[0] + mins, secs = divmod(elapsed, 60) + + xs_arr = np.array(xs) + ys_arr = np.array(ys) + dists = np.hypot(np.diff(xs_arr), np.diff(ys_arr)) + total_dist = float(np.sum(dists)) + + print(f"Recorded {len(xs)} points over {int(mins)}m{secs:.0f}s, {total_dist:.1f}m traveled.") + + fig, ax = plt.subplots(figsize=(10, 10)) + + for i in range(len(xs_arr) - 1): + ax.plot(xs_arr[i : i + 2], ys_arr[i : i + 2], color="blue", alpha=0.2, linewidth=2) + + ax.plot(xs_arr[0], ys_arr[0], "go", markersize=10, label="Start") + ax.plot(xs_arr[-1], ys_arr[-1], "ro", markersize=10, label="End") + + ax.set_xlabel("X (m)") + ax.set_ylabel("Y (m)") + ax.set_title(f"Patrol Path — {int(mins)}m{secs:.0f}s, {total_dist:.1f}m traveled") + ax.set_aspect("equal") + ax.legend() + ax.grid(True, alpha=0.3) + + fig.tight_layout() + fig.savefig(args.output, dpi=150) + print(f"Saved plot to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 1fbd29f86f..1535885edf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -255,6 +255,7 @@ dev = [ # Types "lxml-stubs>=0.5.1,<1", "pandas-stubs>=2.3.2.250926,<3", + "scipy-stubs>=1.15.0", "types-PySocks>=1.7.1.20251001,<2", "types-PyYAML>=6.0.12.20250915,<7", "types-colorama>=0.4.15.20250801,<1", @@ -262,18 +263,17 @@ dev = [ "types-gevent>=25.4.0.20250915,<26", "types-greenlet>=3.2.0.20250915,<4", "types-jmespath>=1.0.2.20250809,<2", + "types-requests>=2.32.4.20260107,<3", "types-jsonschema>=4.25.1.20251009,<5", "types-networkx>=3.5.0.20251001,<4", "types-protobuf>=6.32.1.20250918,<7", - "types-psutil>=7.0.0.20251001,<8", + "types-psutil>=7.2.2.20260130,<8", + "types-psycopg2>=2.9.21.20251012", "types-pytz>=2025.2.0.20250809,<2026", "types-simplejson>=3.20.0.20250822,<4", "types-tabulate>=0.9.0.20241207,<1", "types-tensorflow>=2.18.0.20251008,<3", "types-tqdm>=4.67.0.20250809,<5", - "types-psycopg2>=2.9.21.20251012", - "scipy-stubs>=1.15.0", - "types-psutil>=7.2.2.20260130,<8", # Tools "py-spy", @@ -400,6 +400,7 @@ module = [ "plotext", "plum.*", "portal", + "psutil", "pycuda", "pycuda.*", "pydrake", @@ -408,11 +409,14 @@ module = [ "pyzed.*", "rclpy.*", "sam2.*", + "scipy", + "scipy.*", "sensor_msgs.*", "sqlite_vec", "std_msgs.*", "tf2_msgs.*", "torchreid", + "turbojpeg", "ultralytics.*", "unitree_webrtc_connect.*", "xarm.*", diff --git a/uv.lock b/uv.lock index 0d6a3a88ab..b940d88ff8 100644 --- a/uv.lock +++ b/uv.lock @@ -1809,6 +1809,7 @@ dds = [ { name = "types-pysocks" }, { name = "types-pytz" }, { name = "types-pyyaml" }, + { name = "types-requests" }, { name = "types-simplejson" }, { name = "types-tabulate" }, { name = "types-tensorflow" }, @@ -1848,6 +1849,7 @@ dev = [ { name = "types-pysocks" }, { name = "types-pytz" }, { name = "types-pyyaml" }, + { name = "types-requests" }, { name = "types-simplejson" }, { name = "types-tabulate" }, { name = "types-tensorflow" }, @@ -2128,12 +2130,12 @@ requires-dist = [ { name = "types-jsonschema", marker = "extra == 'dev'", specifier = ">=4.25.1.20251009,<5" }, { name = "types-networkx", marker = "extra == 'dev'", specifier = ">=3.5.0.20251001,<4" }, { name = "types-protobuf", marker = "extra == 'dev'", specifier = ">=6.32.1.20250918,<7" }, - { name = "types-psutil", marker = "extra == 'dev'", specifier = ">=7.0.0.20251001,<8" }, { name = "types-psutil", marker = "extra == 'dev'", specifier = ">=7.2.2.20260130,<8" }, { name = "types-psycopg2", marker = "extra == 'dev'", specifier = ">=2.9.21.20251012" }, { name = "types-pysocks", marker = "extra == 'dev'", specifier = ">=1.7.1.20251001,<2" }, { name = "types-pytz", marker = "extra == 'dev'", specifier = ">=2025.2.0.20250809,<2026" }, { name = "types-pyyaml", marker = "extra == 'dev'", specifier = ">=6.0.12.20250915,<7" }, + { name = "types-requests", marker = "extra == 'dev'", specifier = ">=2.32.4.20260107,<3" }, { name = "types-simplejson", marker = "extra == 'dev'", specifier = ">=3.20.0.20250822,<4" }, { name = "types-tabulate", marker = "extra == 'dev'", specifier = ">=0.9.0.20241207,<1" }, { name = "types-tensorflow", marker = "extra == 'dev'", specifier = ">=2.18.0.20251008,<3" }, From 3afde0182ca5f5bcdd5da351dfc2adceb7c2c42f Mon Sep 17 00:00:00 2001 From: RD <63036454+ruthwikdasyam@users.noreply.github.com> Date: Wed, 18 Mar 2026 20:30:38 -0700 Subject: [PATCH 17/42] fix: rename teleop blueprints, remove VisualizingTeleopModule (#1602) * removed redundant rerun teleop methods * teleop blueprints rename * pre-commit fixes * fix: phone teleop import * fix: comments --- AGENTS.md | 2 +- dimos/robot/all_blueprints.py | 16 +++--- dimos/teleop/README.md | 10 ++-- dimos/teleop/phone/README.md | 4 +- dimos/teleop/phone/blueprints.py | 8 +-- dimos/teleop/quest/README.md | 10 ++-- dimos/teleop/quest/blueprints.py | 37 +++++--------- dimos/teleop/quest/quest_extensions.py | 42 --------------- dimos/teleop/utils/teleop_visualization.py | 59 ---------------------- 9 files changed, 34 insertions(+), 154 deletions(-) delete mode 100644 dimos/teleop/utils/teleop_visualization.py diff --git a/AGENTS.md b/AGENTS.md index 34c33d9a02..9a5f7f5c17 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -43,7 +43,7 @@ dimos restart # stop + re-run with same original args | `unitree-g1-agentic-sim` | G1 | sim | GPT-4o (G1 prompt) | — | Full agentic sim, no real robot needed | | `xarm-perception-agent` | xArm | real | GPT-4o | — | Manipulation + perception + agent | | `xarm7-trajectory-sim` | xArm7 | sim | — | — | Trajectory planning sim | -| `arm-teleop-xarm7` | xArm7 | real | — | — | Quest VR teleop | +| `teleop-quest-xarm7` | xArm7 | real | — | — | Quest VR teleop | | `dual-xarm6-planner` | xArm6×2 | real | — | — | Dual-arm motion planner | Run `dimos list` for the full list. diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index e82cb656ce..0d4225e463 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -16,11 +16,6 @@ # Run `pytest dimos/robot/test_all_blueprints_generation.py` to regenerate. all_blueprints = { - "arm-teleop": "dimos.teleop.quest.blueprints:arm_teleop", - "arm-teleop-dual": "dimos.teleop.quest.blueprints:arm_teleop_dual", - "arm-teleop-piper": "dimos.teleop.quest.blueprints:arm_teleop_piper", - "arm-teleop-visualizing": "dimos.teleop.quest.blueprints:arm_teleop_visualizing", - "arm-teleop-xarm7": "dimos.teleop.quest.blueprints:arm_teleop_xarm7", "coordinator-basic": "dimos.control.blueprints:coordinator_basic", "coordinator-cartesian-ik-mock": "dimos.control.blueprints:coordinator_cartesian_ik_mock", "coordinator-cartesian-ik-piper": "dimos.control.blueprints:coordinator_cartesian_ik_piper", @@ -60,9 +55,13 @@ "mid360-fastlio": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio", "mid360-fastlio-voxels": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_voxels", "mid360-fastlio-voxels-native": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_voxels_native", - "phone-go2-fleet-teleop": "dimos.teleop.phone.blueprints:phone_go2_fleet_teleop", - "phone-go2-teleop": "dimos.teleop.phone.blueprints:phone_go2_teleop", - "simple-phone-teleop": "dimos.teleop.phone.blueprints:simple_phone_teleop", + "teleop-phone": "dimos.teleop.phone.blueprints:teleop_phone", + "teleop-phone-go2": "dimos.teleop.phone.blueprints:teleop_phone_go2", + "teleop-phone-go2-fleet": "dimos.teleop.phone.blueprints:teleop_phone_go2_fleet", + "teleop-quest-dual": "dimos.teleop.quest.blueprints:teleop_quest_dual", + "teleop-quest-piper": "dimos.teleop.quest.blueprints:teleop_quest_piper", + "teleop-quest-rerun": "dimos.teleop.quest.blueprints:teleop_quest_rerun", + "teleop-quest-xarm7": "dimos.teleop.quest.blueprints:teleop_quest_xarm7", "uintree-g1-primitive-no-nav": "dimos.robot.unitree.g1.blueprints.primitive.uintree_g1_primitive_no_nav:uintree_g1_primitive_no_nav", "unitree-g1": "dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1:unitree_g1", "unitree-g1-agentic": "dimos.robot.unitree.g1.blueprints.agentic.unitree_g1_agentic:unitree_g1_agentic", @@ -143,7 +142,6 @@ "temporal-memory": "dimos.perception.experimental.temporal_memory.temporal_memory", "twist-teleop-module": "dimos.teleop.quest.quest_extensions", "unitree-skills": "dimos.robot.unitree.unitree_skill_container", - "visualizing-teleop-module": "dimos.teleop.quest.quest_extensions", "vlm-agent": "dimos.agents.vlm_agent", "vlm-stream-tester": "dimos.agents.vlm_stream_tester", "voxel-mapper": "dimos.mapping.voxels", diff --git a/dimos/teleop/README.md b/dimos/teleop/README.md index fac35ab512..c29ac5011e 100644 --- a/dimos/teleop/README.md +++ b/dimos/teleop/README.md @@ -35,9 +35,6 @@ Toggle-based engage — press primary button once to engage, press again to dise ### TwistTeleopModule Outputs TwistStamped (linear + angular velocity) instead of PoseStamped. -### VisualizingTeleopModule -Adds Rerun visualization for debugging. Extends ArmTeleopModule (toggle engage). - ### PhoneTeleopModule Base phone teleop module. Receives orientation + gyro data from phone motion sensors, computes velocity commands from orientation deltas. @@ -68,7 +65,7 @@ Filters to mobile-base axes (linear.x, linear.y, angular.z) and publishes as `Tw teleop/ ├── quest/ │ ├── quest_teleop_module.py # Base Quest teleop module -│ ├── quest_extensions.py # ArmTeleop, TwistTeleop, VisualizingTeleop +│ ├── quest_extensions.py # ArmTeleop, TwistTeleop │ ├── quest_types.py # QuestControllerState, Buttons │ └── web/ │ └── static/index.html # WebXR client @@ -80,15 +77,14 @@ teleop/ │ └── static/index.html # Mobile sensor web app ├── utils/ │ ├── teleop_transforms.py # WebXR → robot frame math -│ └── teleop_visualization.py # Rerun visualization helpers └── blueprints.py # Module blueprints for easy instantiation ``` ## Quick Start ```bash -dimos run arm-teleop # Quest arm teleop -dimos run phone-go2-teleop # Phone → Go2 +dimos run teleop-quest-rerun # Quest teleop + Rerun viz +dimos run teleop-phone-go2 # Phone → Go2 ``` Open `https://:/teleop` on device. Accept the self-signed certificate. diff --git a/dimos/teleop/phone/README.md b/dimos/teleop/phone/README.md index 5f8541f602..da84bfd124 100644 --- a/dimos/teleop/phone/README.md +++ b/dimos/teleop/phone/README.md @@ -12,8 +12,8 @@ Phone Browser ──WebSocket──→ Embedded HTTPS Server ──→ Phone ## Running ```bash -dimos run phone-go2-teleop # Go2 -dimos run simple-phone-teleop # Generic ground robot +dimos run teleop-phone-go2 # Go2 +dimos run teleop-phone # Generic ground robot ``` Open `https://:8444/teleop` on phone. Accept cert, allow sensors, connect, hold to drive. diff --git a/dimos/teleop/phone/blueprints.py b/dimos/teleop/phone/blueprints.py index 6e68e726e3..86e1154d92 100644 --- a/dimos/teleop/phone/blueprints.py +++ b/dimos/teleop/phone/blueprints.py @@ -19,21 +19,21 @@ from dimos.teleop.phone.phone_extensions import simple_phone_teleop_module # Simple phone teleop (mobile base axis filtering + cmd_vel output) -simple_phone_teleop = autoconnect( +teleop_phone = autoconnect( simple_phone_teleop_module(), ) # Phone teleop wired to Unitree Go2 -phone_go2_teleop = autoconnect( +teleop_phone_go2 = autoconnect( simple_phone_teleop_module(), unitree_go2_basic, ) # Phone teleop wired to Go2 fleet — twist commands sent to all robots -phone_go2_fleet_teleop = autoconnect( +teleop_phone_go2_fleet = autoconnect( simple_phone_teleop_module(), unitree_go2_fleet, ) -__all__ = ["phone_go2_fleet_teleop", "phone_go2_teleop", "simple_phone_teleop"] +__all__ = ["teleop_phone", "teleop_phone_go2", "teleop_phone_go2_fleet"] diff --git a/dimos/teleop/quest/README.md b/dimos/teleop/quest/README.md index 0b0e2b8402..4e8164ec9b 100644 --- a/dimos/teleop/quest/README.md +++ b/dimos/teleop/quest/README.md @@ -12,10 +12,10 @@ Quest Browser ──WebSocket──→ Embedded HTTPS Server ──→ Quest ## Running ```bash -dimos run arm-teleop # Basic arm teleop -dimos run arm-teleop-xarm6 # XArm6 -dimos run arm-teleop-piper # Piper -dimos run arm-teleop-dual # Dual arm +dimos run teleop-quest-rerun # Quest teleop + Rerun viz +dimos run teleop-quest-xarm7 # XArm7 +dimos run teleop-quest-piper # Piper +dimos run teleop-quest-dual # Dual arm ``` Open `https://:8443/teleop` on Quest browser. Accept cert, tap Connect. @@ -42,7 +42,7 @@ Open `https://:8443/teleop` on Quest browser. Accept cert, tap Connect. ``` quest/ ├── quest_teleop_module.py # Base module -├── quest_extensions.py # ArmTeleop, TwistTeleop, VisualizingTeleop +├── quest_extensions.py # ArmTeleop, TwistTeleop ├── quest_types.py # QuestControllerState, Buttons ├── blueprints.py └── web/static/index.html # WebXR client diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index a3aa54ee08..da07a1bdd4 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -23,23 +23,14 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.teleop.quest.quest_extensions import arm_teleop_module, visualizing_teleop_module +from dimos.teleop.quest.quest_extensions import arm_teleop_module from dimos.teleop.quest.quest_types import Buttons +from dimos.visualization.rerun.bridge import rerun_bridge -# Arm teleop with press-and-hold engage -arm_teleop = autoconnect( +# Arm teleop with press-and-hold engage (has rerun viz) +teleop_quest_rerun = autoconnect( arm_teleop_module(), -).transports( - { - ("left_controller_output", PoseStamped): LCMTransport("/teleop/left_delta", PoseStamped), - ("right_controller_output", PoseStamped): LCMTransport("/teleop/right_delta", PoseStamped), - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - } -) - -# Arm teleop with Rerun visualization -arm_teleop_visualizing = autoconnect( - visualizing_teleop_module(), + rerun_bridge(), ).transports( { ("left_controller_output", PoseStamped): LCMTransport("/teleop/left_delta", PoseStamped), @@ -50,9 +41,7 @@ # Single XArm7 teleop: right controller -> xarm7 -# Usage: dimos run arm-teleop-xarm7 - -arm_teleop_xarm7 = autoconnect( +teleop_quest_xarm7 = autoconnect( arm_teleop_module(task_names={"right": "teleop_xarm"}), coordinator_teleop_xarm7, ).transports( @@ -66,8 +55,7 @@ # Single Piper teleop: left controller -> piper arm -# Usage: dimos run arm-teleop-piper -arm_teleop_piper = autoconnect( +teleop_quest_piper = autoconnect( arm_teleop_module(task_names={"left": "teleop_piper"}), coordinator_teleop_piper, ).transports( @@ -81,7 +69,7 @@ # Dual arm teleop: right -> piper, left -> xarm6 (TeleopIK) -arm_teleop_dual = autoconnect( +teleop_quest_dual = autoconnect( arm_teleop_module(task_names={"right": "teleop_piper", "left": "teleop_xarm"}), coordinator_teleop_dual, ).transports( @@ -98,9 +86,8 @@ __all__ = [ - "arm_teleop", - "arm_teleop_dual", - "arm_teleop_piper", - "arm_teleop_visualizing", - "arm_teleop_xarm7", + "teleop_quest_dual", + "teleop_quest_piper", + "teleop_quest_rerun", + "teleop_quest_xarm7", ] diff --git a/dimos/teleop/quest/quest_extensions.py b/dimos/teleop/quest/quest_extensions.py index 46e868837d..674fc36f1e 100644 --- a/dimos/teleop/quest/quest_extensions.py +++ b/dimos/teleop/quest/quest_extensions.py @@ -17,7 +17,6 @@ Available subclasses: - ArmTeleopModule: Per-hand press-and-hold engage (X/A hold to track), task name routing - TwistTeleopModule: Outputs Twist instead of PoseStamped - - VisualizingTeleopModule: Adds Rerun visualization (inherits press-and-hold engage) """ from typing import Any @@ -29,10 +28,6 @@ from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped from dimos.teleop.quest.quest_teleop_module import Hand, QuestTeleopConfig, QuestTeleopModule from dimos.teleop.quest.quest_types import Buttons, QuestControllerState -from dimos.teleop.utils.teleop_visualization import ( - visualize_buttons, - visualize_pose, -) class TwistTeleopConfig(QuestTeleopConfig): @@ -138,51 +133,14 @@ def _publish_button_state( self.buttons.publish(buttons) -class VisualizingTeleopModule(ArmTeleopModule): - """Quest teleop with Rerun visualization. - - Adds visualization of controller poses and trigger values to Rerun. - Useful for debugging and development. - - Outputs: - - left_controller_output: PoseStamped (inherited) - - right_controller_output: PoseStamped (inherited) - - buttons: Buttons (inherited) - """ - - def _get_output_pose(self, hand: Hand) -> PoseStamped | None: - """Get output pose and visualize in Rerun.""" - output_pose = super()._get_output_pose(hand) - - if output_pose is not None: - current_pose = self._current_poses.get(hand) - controller = self._controllers.get(hand) - if current_pose is not None: - label = "left" if hand == Hand.LEFT else "right" - visualize_pose(current_pose, label) - - if controller: - visualize_buttons( - label, - primary=controller.primary, - secondary=controller.secondary, - grip=controller.grip, - trigger=controller.trigger, - ) - return output_pose - - # Module blueprints for easy instantiation twist_teleop_module = TwistTeleopModule.blueprint arm_teleop_module = ArmTeleopModule.blueprint -visualizing_teleop_module = VisualizingTeleopModule.blueprint __all__ = [ "ArmTeleopConfig", "ArmTeleopModule", "TwistTeleopModule", - "VisualizingTeleopModule", "arm_teleop_module", "twist_teleop_module", - "visualizing_teleop_module", ] diff --git a/dimos/teleop/utils/teleop_visualization.py b/dimos/teleop/utils/teleop_visualization.py deleted file mode 100644 index 5a7acd06e9..0000000000 --- a/dimos/teleop/utils/teleop_visualization.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Teleop visualization utilities for Rerun.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import rerun as rr - -from dimos.utils.logging_config import setup_logger - -if TYPE_CHECKING: - from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped - -logger = setup_logger() - - -def visualize_pose(pose_stamped: PoseStamped, controller_label: str) -> None: - """Visualize controller absolute pose in Rerun.""" - try: - rr.log(f"world/teleop/{controller_label}_controller", pose_stamped.to_rerun()) # type: ignore[no-untyped-call] - rr.log(f"world/teleop/{controller_label}_controller/axes", rr.TransformAxes3D(0.10)) # type: ignore[attr-defined] - except Exception as e: - logger.debug(f"Failed to log {controller_label} controller to Rerun: {e}") - - -def visualize_buttons( - controller_label: str, - primary: bool = False, - secondary: bool = False, - grip: float = 0.0, - trigger: float = 0.0, -) -> None: - """Visualize button states in Rerun as scalar time series.""" - try: - base_path = f"world/teleop/{controller_label}_controller" - rr.log(f"{base_path}/primary", rr.Scalars(float(primary))) # type: ignore[attr-defined] - rr.log(f"{base_path}/secondary", rr.Scalars(float(secondary))) # type: ignore[attr-defined] - rr.log(f"{base_path}/grip", rr.Scalars(grip)) # type: ignore[attr-defined] - rr.log(f"{base_path}/trigger", rr.Scalars(trigger)) # type: ignore[attr-defined] - except Exception as e: - logger.debug(f"Failed to log {controller_label} buttons to Rerun: {e}") - - -__all__ = ["visualize_buttons", "visualize_pose"] From c6f1842e477c618ab029ab842d61b6cd74d411dd Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Thu, 19 Mar 2026 08:05:03 +0200 Subject: [PATCH 18/42] feat(test): add leaderboard (#1580) --- bin/test-speed-leaderboard | 220 +++++++++++++++++++++++++++++++++++++ 1 file changed, 220 insertions(+) create mode 100755 bin/test-speed-leaderboard diff --git a/bin/test-speed-leaderboard b/bin/test-speed-leaderboard new file mode 100755 index 0000000000..d58bbe1a9e --- /dev/null +++ b/bin/test-speed-leaderboard @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 + +import ast +from collections import defaultdict +import os +import re +import subprocess +import sys +import tempfile +import xml.etree.ElementTree as ET + +ANSI_RE = re.compile(r"\x1b\[[0-9;]*m") + + +def get_repo_root(): + result = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + capture_output=True, + text=True, + check=True, + ) + return result.stdout.strip() + + +def classname_to_filepath(classname): + parts = classname.split(".") + class_name = None + while len(parts) > 1 and parts[-1][0:1].isupper(): + class_name = parts.pop() + return os.path.join(*parts) + ".py", class_name + + +def run_pytest(extra_args): + """Run pytest, return list of (file_path, class_name|None, func_name, duration).""" + xml_file = tempfile.NamedTemporaryFile(suffix=".xml", delete=False) + xml_file.close() + try: + cmd = [ + "pytest", + "dimos", + "-m", + "not (tool or mujoco)", + f"--junit-xml={xml_file.name}", + "--tb=no", + *extra_args, + ] + print("$ " + " ".join(cmd), file=sys.stderr, flush=True) + result = subprocess.run(cmd, capture_output=True, text=True) + + # Show the final summary line + for raw in reversed(result.stdout.splitlines()): + line = ANSI_RE.sub("", raw).strip() + if "passed" in line or "failed" in line or "error" in line: + print(line, file=sys.stderr) + break + + if result.returncode not in (0, 1): + print(result.stderr, file=sys.stderr) + sys.exit(result.returncode) + + tree = ET.parse(xml_file.name) + finally: + os.unlink(xml_file.name) + + tests = [] + for tc in tree.iter("testcase"): + classname = tc.get("classname", "") + name = tc.get("name", "") + time_s = float(tc.get("time", "0")) + # Strip parametrize suffixes like [param1-param2] + func_name = re.sub(r"\[.*\]$", "", name) + file_path, class_name = classname_to_filepath(classname) + tests.append((file_path, class_name, func_name, time_s)) + return tests + + +def find_function_lines(tree, class_name, func_name): + """Return (start_line, end_line) of a test function in an AST.""" + for node in ast.walk(tree): + if class_name and isinstance(node, ast.ClassDef) and node.name == class_name: + for item in node.body: + if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): + if item.name == func_name: + return item.lineno, item.end_lineno + elif not class_name and isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + if node.name == func_name: + return node.lineno, node.end_lineno + return None, None + + +def blame_lines(file_path, start, end): + """Return {author: line_count} for a line range via git blame.""" + try: + result = subprocess.run( + ["git", "blame", "--line-porcelain", f"-L{start},{end}", file_path], + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError: + return {} + + counts = defaultdict(int) + current_line = None + for line in result.stdout.splitlines(): + tokens = line.split() + if ( + len(tokens) >= 3 + and len(tokens[0]) == 40 + and all(c in "0123456789abcdef" for c in tokens[0]) + ): + current_line = int(tokens[2]) + elif line.startswith("author ") and current_line is not None: + counts[line[7:]] += 1 + return dict(counts) + + +def format_time(seconds): + if seconds >= 60: + m, s = divmod(seconds, 60) + return f"{int(m)}m {s:.1f}s" + if seconds >= 1: + return f"{seconds:.2f}s" + if seconds >= 0.001: + return f"{seconds * 1000:.1f}ms" + return f"{seconds * 1_000_000:.0f}us" + + +def main(): + repo_root = get_repo_root() + os.chdir(repo_root) + + extra_args = sys.argv[1:] + + print("Running pytest to collect test durations...", file=sys.stderr, flush=True) + tests = run_pytest(extra_args) + if not tests: + print("No test results found.", file=sys.stderr) + sys.exit(1) + + print( + f"Analysing {len(tests)} tests with git blame...", + file=sys.stderr, + flush=True, + ) + + stats = defaultdict(lambda: {"time": 0.0, "lines": 0}) + ast_cache = {} + blame_cache = {} + skipped = 0 + + for file_path, class_name, func_name, duration in tests: + if not func_name: + skipped += 1 + continue + + if file_path not in ast_cache: + try: + with open(file_path) as f: + ast_cache[file_path] = ast.parse(f.read()) + except (FileNotFoundError, SyntaxError): + ast_cache[file_path] = None + tree = ast_cache[file_path] + if tree is None: + skipped += 1 + continue + + start, end = find_function_lines(tree, class_name, func_name) + if start is None: + skipped += 1 + continue + + cache_key = (file_path, start, end) + if cache_key not in blame_cache: + blame_cache[cache_key] = blame_lines(file_path, start, end) + author_lines = blame_cache[cache_key] + total_lines = sum(author_lines.values()) + if total_lines == 0: + skipped += 1 + continue + + for author, lines in author_lines.items(): + frac = lines / total_lines + stats[author]["time"] += duration * frac + stats[author]["lines"] += lines + + if not stats: + print("No blame data could be collected.", file=sys.stderr) + sys.exit(1) + if skipped: + print(f"({skipped} tests skipped — could not resolve source)", file=sys.stderr) + + # Sort by time-per-line ascending (fastest first) + ranked = sorted( + stats.items(), + key=lambda x: x[1]["time"] / x[1]["lines"] if x[1]["lines"] else float("inf"), + ) + + total_time = sum(s["time"] for _, s in ranked) + total_lines = sum(s["lines"] for _, s in ranked) + + print() + hdr = f" {'Committer':<30} {'Total Time':>12} {'Lines':>7} {'Time / Line':>12}" + sep = " " + "\u2500" * (len(hdr) - 2) + print(hdr) + print(sep) + for rank, (author, s) in enumerate(ranked, 1): + t = s["time"] + n = s["lines"] + tpl = t / n if n else 0 + print(f"{rank:>2}. {author:<28} {format_time(t):>12} {n:>7} {format_time(tpl):>12}") + print(sep) + tpl_all = total_time / total_lines if total_lines else 0 + print( + f" {'TOTAL':<30} {format_time(total_time):>12} {total_lines:>7} {format_time(tpl_all):>12}" + ) + + +if __name__ == "__main__": + main() From bdd06d4e04186f3fd0749e7798a9d8b2887b9a97 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Thu, 19 Mar 2026 08:07:59 +0200 Subject: [PATCH 19/42] fix(florence): fix text failure (#1582) --- dimos/models/vl/florence.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/dimos/models/vl/florence.py b/dimos/models/vl/florence.py index 6fa7ba3d12..143d4c7ea9 100644 --- a/dimos/models/vl/florence.py +++ b/dimos/models/vl/florence.py @@ -31,6 +31,16 @@ class CaptionDetail(Enum): NORMAL = "" DETAILED = "" + @classmethod + def from_str(cls, name: str) -> "CaptionDetail": + _ALIASES: dict[str, CaptionDetail] = { + "brief": cls.BRIEF, + "normal": cls.NORMAL, + "detailed": cls.DETAILED, + "more_detailed": cls.DETAILED, + } + return _ALIASES.get(name.lower()) or cls[name.upper()] + class Florence2Model(HuggingFaceModel, Captioner): """Florence-2 captioning model from Microsoft. @@ -74,13 +84,18 @@ def _clean_caption(text: str) -> str: return text[len(prefix) :] return text - def caption(self, image: Image) -> str: + def caption(self, image: Image, detail: str | CaptionDetail | None = None) -> str: """Generate a caption for the image. Returns: Text description of the image """ - task_prompt = self._task_prompt + if detail is None: + task_prompt = self._task_prompt + elif isinstance(detail, CaptionDetail): + task_prompt = detail.value + else: + task_prompt = CaptionDetail.from_str(detail).value # Convert to PIL pil_image = PILImage.fromarray(image.to_rgb().data) From b9cca6c38aee0b2fa17e467606e5e984b3ee8e50 Mon Sep 17 00:00:00 2001 From: leshy Date: Thu, 19 Mar 2026 10:26:54 +0200 Subject: [PATCH 20/42] event based sub callback collector for tests (#1605) * event based sub callback collector for tests * shorter wait for no msg * fix(tests): raise AssertionError on CallbackCollector timeout Instead of silently returning when messages never arrive, wait() now raises with a clear message showing expected vs received count. --- dimos/protocol/pubsub/impl/test_lcmpubsub.py | 77 +++++---------- dimos/protocol/pubsub/impl/test_rospubsub.py | 98 +++++++------------- dimos/protocol/pubsub/test_pattern_sub.py | 72 +++++++------- dimos/protocol/pubsub/test_spec.py | 78 ++++------------ dimos/utils/testing/collector.py | 50 ++++++++++ 5 files changed, 159 insertions(+), 216 deletions(-) create mode 100644 dimos/utils/testing/collector.py diff --git a/dimos/protocol/pubsub/impl/test_lcmpubsub.py b/dimos/protocol/pubsub/impl/test_lcmpubsub.py index ba29c70958..c53bc32da2 100644 --- a/dimos/protocol/pubsub/impl/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/test_lcmpubsub.py @@ -13,7 +13,6 @@ # limitations under the License. from collections.abc import Generator -import time from typing import Any import pytest @@ -27,6 +26,7 @@ PickleLCM, Topic, ) +from dimos.utils.testing.collector import CallbackCollector @pytest.fixture @@ -74,25 +74,19 @@ def __eq__(self, other: object) -> bool: def test_LCMPubSubBase_pubsub(lcm_pub_sub_base: LCMPubSubBase) -> None: lcm = lcm_pub_sub_base - - received_messages: list[tuple[Any, Any]] = [] + collector = CallbackCollector(1) topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) test_message = MockLCMMessage("test_data") - def callback(msg: Any, topic: Any) -> None: - received_messages.append((msg, topic)) - - lcm.subscribe(topic, callback) + lcm.subscribe(topic, collector) lcm.publish(topic, test_message.lcm_encode()) - time.sleep(0.1) + collector.wait() - assert len(received_messages) == 1 + assert len(collector.results) == 1 - received_data = received_messages[0][0] - received_topic = received_messages[0][1] - - print(f"Received data: {received_data}, Topic: {received_topic}") + received_data = collector.results[0][0] + received_topic = collector.results[0][1] assert isinstance(received_data, bytes) assert received_data.decode() == "test_data" @@ -102,24 +96,19 @@ def callback(msg: Any, topic: Any) -> None: def test_lcm_autodecoder_pubsub(lcm: LCM) -> None: - received_messages: list[tuple[Any, Any]] = [] + collector = CallbackCollector(1) topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) test_message = MockLCMMessage("test_data") - def callback(msg: Any, topic: Any) -> None: - received_messages.append((msg, topic)) - - lcm.subscribe(topic, callback) + lcm.subscribe(topic, collector) lcm.publish(topic, test_message) - time.sleep(0.1) + collector.wait() - assert len(received_messages) == 1 + assert len(collector.results) == 1 - received_data = received_messages[0][0] - received_topic = received_messages[0][1] - - print(f"Received data: {received_data}, Topic: {received_topic}") + received_data = collector.results[0][0] + received_topic = collector.results[0][1] assert isinstance(received_data, MockLCMMessage) assert received_data == test_message @@ -138,24 +127,18 @@ def callback(msg: Any, topic: Any) -> None: # passes some geometry types through LCM @pytest.mark.parametrize("test_message", test_msgs) def test_lcm_geometry_msgs_pubsub(test_message: Any, lcm: LCM) -> None: - received_messages: list[tuple[Any, Any]] = [] + collector = CallbackCollector(1) topic = Topic(topic="/test_topic", lcm_type=test_message.__class__) - def callback(msg: Any, topic: Any) -> None: - received_messages.append((msg, topic)) - - lcm.subscribe(topic, callback) + lcm.subscribe(topic, collector) lcm.publish(topic, test_message) + collector.wait() - time.sleep(0.1) - - assert len(received_messages) == 1 + assert len(collector.results) == 1 - received_data = received_messages[0][0] - received_topic = received_messages[0][1] - - print(f"Received data: {received_data}, Topic: {received_topic}") + received_data = collector.results[0][0] + received_topic = collector.results[0][1] assert isinstance(received_data, test_message.__class__) assert received_data == test_message @@ -163,36 +146,26 @@ def callback(msg: Any, topic: Any) -> None: assert isinstance(received_topic, Topic) assert received_topic == topic - print(test_message, topic) - # passes some geometry types through pickle LCM @pytest.mark.parametrize("test_message", test_msgs) def test_lcm_geometry_msgs_autopickle_pubsub(test_message: Any, pickle_lcm: PickleLCM) -> None: lcm = pickle_lcm - received_messages: list[tuple[Any, Any]] = [] + collector = CallbackCollector(1) topic = Topic(topic="/test_topic") - def callback(msg: Any, topic: Any) -> None: - received_messages.append((msg, topic)) - - lcm.subscribe(topic, callback) + lcm.subscribe(topic, collector) lcm.publish(topic, test_message) + collector.wait() - time.sleep(0.1) + assert len(collector.results) == 1 - assert len(received_messages) == 1 - - received_data = received_messages[0][0] - received_topic = received_messages[0][1] - - print(f"Received data: {received_data}, Topic: {received_topic}") + received_data = collector.results[0][0] + received_topic = collector.results[0][1] assert isinstance(received_data, test_message.__class__) assert received_data == test_message assert isinstance(received_topic, Topic) assert received_topic == topic - - print(test_message, topic) diff --git a/dimos/protocol/pubsub/impl/test_rospubsub.py b/dimos/protocol/pubsub/impl/test_rospubsub.py index ef9df74227..6f29b3591b 100644 --- a/dimos/protocol/pubsub/impl/test_rospubsub.py +++ b/dimos/protocol/pubsub/impl/test_rospubsub.py @@ -13,7 +13,6 @@ # limitations under the License. from collections.abc import Generator -import threading from dimos_lcm.geometry_msgs import PointStamped import numpy as np @@ -28,6 +27,7 @@ # Add msg_name to LCM PointStamped for testing nested message conversion PointStamped.msg_name = "geometry_msgs.PointStamped" from dimos.utils.data import get_data +from dimos.utils.testing.collector import CallbackCollector from dimos.utils.testing.replay import TimedSensorReplay @@ -57,20 +57,14 @@ def test_basic_conversion(publisher, subscriber): Simple flat dimos.msgs type with no nesting (just x/y/z floats). """ topic = ROSTopic("/test_ros_topic", Vector3) + collector = CallbackCollector(1) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, Vector3(1.0, 2.0, 3.0)) - assert event.wait(timeout=2.0), "No message received" - assert len(received) == 1 - msg = received[0] + collector.wait() + assert len(collector.results) == 1 + msg = collector.results[0][0] assert msg.x == 1.0 assert msg.y == 2.0 assert msg.z == 3.0 @@ -95,21 +89,15 @@ def test_pointcloud2_pubsub(publisher, subscriber): assert len(original) > 0, "Loaded empty pointcloud" topic = ROSTopic("/test_pointcloud2", PointCloud2) + collector = CallbackCollector(1, timeout=5.0) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, original) - assert event.wait(timeout=5.0), "No PointCloud2 message received" - assert len(received) == 1 + collector.wait() + assert len(collector.results) == 1 - converted = received[0] + converted = collector.results[0][0] # Verify point cloud data is preserved original_points, _ = original.as_numpy() @@ -147,20 +135,14 @@ def test_pointcloud2_empty_pubsub(publisher, subscriber): ) topic = ROSTopic("/test_empty_pointcloud", PointCloud2) + collector = CallbackCollector(1) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, original) - assert event.wait(timeout=2.0), "No empty PointCloud2 message received" - assert len(received) == 1 - assert len(received[0]) == 0 + collector.wait() + assert len(collector.results) == 1 + assert len(collector.results[0][0]) == 0 @pytest.mark.skipif_no_ros @@ -178,21 +160,15 @@ def test_posestamped_pubsub(publisher, subscriber): ) topic = ROSTopic("/test_posestamped", PoseStamped) + collector = CallbackCollector(1) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, original) - assert event.wait(timeout=2.0), "No PoseStamped message received" - assert len(received) == 1 + collector.wait() + assert len(collector.results) == 1 - converted = received[0] + converted = collector.results[0][0] # Verify all fields preserved assert converted.frame_id == original.frame_id @@ -220,21 +196,15 @@ def test_pointstamped_pubsub(publisher, subscriber): original.point.z = 3.5 topic = ROSTopic("/test_pointstamped", PointStamped) + collector = CallbackCollector(1) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, original) - assert event.wait(timeout=2.0), "No PointStamped message received" - assert len(received) == 1 + collector.wait() + assert len(collector.results) == 1 - converted = received[0] + converted = collector.results[0][0] # Verify nested header fields are preserved assert converted.header.frame_id == original.header.frame_id @@ -260,21 +230,15 @@ def test_twist_pubsub(publisher, subscriber): ) topic = ROSTopic("/test_twist", Twist) + collector = CallbackCollector(1) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, original) - assert event.wait(timeout=2.0), "No Twist message received" - assert len(received) == 1 + collector.wait() + assert len(collector.results) == 1 - converted = received[0] + converted = collector.results[0][0] # Verify linear velocity preserved assert converted.linear.x == original.linear.x diff --git a/dimos/protocol/pubsub/test_pattern_sub.py b/dimos/protocol/pubsub/test_pattern_sub.py index ac94ba1b3b..4b888f4bba 100644 --- a/dimos/protocol/pubsub/test_pattern_sub.py +++ b/dimos/protocol/pubsub/test_pattern_sub.py @@ -30,6 +30,7 @@ from dimos.protocol.pubsub.impl.lcmpubsub import LCM, LCMPubSubBase, Topic from dimos.protocol.pubsub.patterns import Glob from dimos.protocol.pubsub.spec import AllPubSub, PubSub +from dimos.utils.testing.collector import CallbackCollector TopicT = TypeVar("TopicT") MsgT = TypeVar("MsgT") @@ -139,22 +140,20 @@ def _topic_matches_prefix(topic: Any, prefix: str = "/") -> bool: @pytest.mark.parametrize("tc", all_cases, ids=lambda c: c.name) def test_subscribe_all_receives_all_topics(tc: Case[Any, Any]) -> None: """Test that subscribe_all receives messages from all topics.""" - received: list[tuple[Any, Any]] = [] + collector = CallbackCollector(len(tc.topic_values)) with tc.pubsub_context() as (pub, sub): - # Filter to only our test topics (LCM multicast can leak from parallel tests) - sub.subscribe_all(lambda msg, topic: received.append((msg, topic))) - time.sleep(0.01) # Allow subscription to be ready + sub.subscribe_all(collector) + time.sleep(0.01) # Allow subscription to register for topic, value in tc.topic_values: pub.publish(topic, value) - time.sleep(0.01) + collector.wait() - assert len(received) == len(tc.topic_values) + assert len(collector.results) == len(tc.topic_values) - # Verify all messages were received - received_msgs = [r[0] for r in received] + received_msgs = [r[0] for r in collector.results] expected_msgs = [v for _, v in tc.topic_values] for expected in expected_msgs: assert expected in received_msgs @@ -163,47 +162,45 @@ def test_subscribe_all_receives_all_topics(tc: Case[Any, Any]) -> None: @pytest.mark.parametrize("tc", all_cases, ids=lambda c: c.name) def test_subscribe_all_unsubscribe(tc: Case[Any, Any]) -> None: """Test that unsubscribe stops receiving messages.""" - received: list[tuple[Any, Any]] = [] + collector = CallbackCollector(1) topic, value = tc.topic_values[0] with tc.pubsub_context() as (pub, sub): - unsub = sub.subscribe_all(lambda msg, topic: received.append((msg, topic))) - time.sleep(0.01) # Allow subscription to be ready + unsub = sub.subscribe_all(collector) + time.sleep(0.01) # Allow subscription to register pub.publish(topic, value) - time.sleep(0.01) - assert len(received) == 1 + collector.wait() + assert len(collector.results) == 1 unsub() pub.publish(topic, value) - time.sleep(0.01) - assert len(received) == 1 # No new messages + time.sleep(0.1) # Wait to confirm no new messages arrive + assert len(collector.results) == 1 # No new messages @pytest.mark.parametrize("tc", all_cases, ids=lambda c: c.name) def test_subscribe_all_with_regular_subscribe(tc: Case[Any, Any]) -> None: """Test that subscribe_all coexists with regular subscriptions.""" - all_received: list[tuple[Any, Any]] = [] + all_collector = CallbackCollector(2) specific_received: list[tuple[Any, Any]] = [] topic1, value1 = tc.topic_values[0] topic2, value2 = tc.topic_values[1] with tc.pubsub_context() as (pub, sub): sub.subscribe_all( - lambda msg, topic: all_received.append((msg, topic)) - if _topic_matches_prefix(topic) - else None + lambda msg, topic: all_collector(msg, topic) if _topic_matches_prefix(topic) else None ) sub.subscribe(topic1, lambda msg, topic: specific_received.append((msg, topic))) - time.sleep(0.01) # Allow subscriptions to be ready + time.sleep(0.01) # Allow subscriptions to register pub.publish(topic1, value1) pub.publish(topic2, value2) - time.sleep(0.01) + all_collector.wait() # subscribe_all gets both - assert len(all_received) == 2 + assert len(all_collector.results) == 2 # specific subscription gets only topic1 assert len(specific_received) == 1 @@ -214,25 +211,24 @@ def test_subscribe_all_with_regular_subscribe(tc: Case[Any, Any]) -> None: def test_subscribe_glob(tc: Case[Any, Any]) -> None: """Test that glob pattern subscriptions receive only matching topics.""" for pattern_topic, expected_indices in tc.glob_patterns: - received: list[tuple[Any, Any]] = [] + collector = CallbackCollector(len(expected_indices)) with tc.pubsub_context() as (pub, sub): - sub.subscribe(pattern_topic, lambda msg, topic, r=received: r.append((msg, topic))) - time.sleep(0.01) # Allow subscription to be ready + sub.subscribe(pattern_topic, collector) + time.sleep(0.01) # Allow subscription to register for topic, value in tc.topic_values: pub.publish(topic, value) - time.sleep(0.01) + collector.wait() - assert len(received) == len(expected_indices), ( + assert len(collector.results) == len(expected_indices), ( f"Expected {len(expected_indices)} messages for pattern {pattern_topic}, " - f"got {len(received)}" + f"got {len(collector.results)}" ) - # Verify we received the expected messages expected_msgs = [tc.topic_values[i][1] for i in expected_indices] - received_msgs = [r[0] for r in received] + received_msgs = [r[0] for r in collector.results] for expected in expected_msgs: assert expected in received_msgs @@ -241,25 +237,23 @@ def test_subscribe_glob(tc: Case[Any, Any]) -> None: def test_subscribe_regex(tc: Case[Any, Any]) -> None: """Test that regex pattern subscriptions receive only matching topics.""" for pattern_topic, expected_indices in tc.regex_patterns: - received: list[tuple[Any, Any]] = [] + collector = CallbackCollector(len(expected_indices)) with tc.pubsub_context() as (pub, sub): - sub.subscribe(pattern_topic, lambda msg, topic, r=received: r.append((msg, topic))) - - time.sleep(0.01) + sub.subscribe(pattern_topic, collector) + time.sleep(0.01) # Allow subscription to register for topic, value in tc.topic_values: pub.publish(topic, value) - time.sleep(0.01) + collector.wait() - assert len(received) == len(expected_indices), ( + assert len(collector.results) == len(expected_indices), ( f"Expected {len(expected_indices)} messages for pattern {pattern_topic}, " - f"got {len(received)}" + f"got {len(collector.results)}" ) - # Verify we received the expected messages expected_msgs = [tc.topic_values[i][1] for i in expected_indices] - received_msgs = [r[0] for r in received] + received_msgs = [r[0] for r in collector.results] for expected in expected_msgs: assert expected in received_msgs diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index e36741bbfd..0e61132c1c 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -17,7 +17,6 @@ import asyncio from collections.abc import Callable, Generator from contextlib import contextmanager -import threading import time from typing import Any @@ -26,6 +25,7 @@ from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic from dimos.protocol.pubsub.impl.memory import Memory +from dimos.utils.testing.collector import CallbackCollector @contextmanager @@ -148,26 +148,14 @@ def shared_memory_cpu_context() -> Generator[PickleSharedMemory, None, None]: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) def test_store(pubsub_context: Callable[[], Any], topic: Any, values: list[Any]) -> None: with pubsub_context() as x: - # Create a list to capture received messages - received_messages: list[Any] = [] - msg_event = threading.Event() - - # Define callback function that stores received messages - def callback(message: Any, _: Any) -> None: - received_messages.append(message) - msg_event.set() - - # Subscribe to the topic with our callback - x.subscribe(topic, callback) + collector = CallbackCollector(1) - # Publish the first value to the topic + x.subscribe(topic, collector) x.publish(topic, values[0]) + collector.wait() - assert msg_event.wait(timeout=1.0), "Timed out waiting for message" - - # Verify the callback was called with the correct value - assert len(received_messages) == 1 - assert received_messages[0] == values[0] + assert len(collector.results) == 1 + assert collector.results[0][0] == values[0] @pytest.mark.parametrize("pubsub_context, topic, values", testdata) @@ -176,36 +164,21 @@ def test_multiple_subscribers( ) -> None: """Test that multiple subscribers receive the same message.""" with pubsub_context() as x: - # Create lists to capture received messages for each subscriber - received_messages_1: list[Any] = [] - received_messages_2: list[Any] = [] - event_1 = threading.Event() - event_2 = threading.Event() - - # Define callback functions - def callback_1(message: Any, topic: Any) -> None: - received_messages_1.append(message) - event_1.set() + collector_1 = CallbackCollector(1) + collector_2 = CallbackCollector(1) - def callback_2(message: Any, topic: Any) -> None: - received_messages_2.append(message) - event_2.set() + x.subscribe(topic, collector_1) + x.subscribe(topic, collector_2) - # Subscribe both callbacks to the same topic - x.subscribe(topic, callback_1) - x.subscribe(topic, callback_2) - - # Publish the first value x.publish(topic, values[0]) - assert event_1.wait(timeout=1.0), "Timed out waiting for subscriber 1" - assert event_2.wait(timeout=1.0), "Timed out waiting for subscriber 2" + collector_1.wait() + collector_2.wait() - # Verify both callbacks received the message - assert len(received_messages_1) == 1 - assert received_messages_1[0] == values[0] - assert len(received_messages_2) == 1 - assert received_messages_2[0] == values[0] + assert len(collector_1.results) == 1 + assert collector_1.results[0][0] == values[0] + assert len(collector_2.results) == 1 + assert collector_2.results[0][0] == values[0] @pytest.mark.parametrize("pubsub_context, topic, values", testdata) @@ -241,28 +214,17 @@ def test_multiple_messages( ) -> None: """Test that subscribers receive multiple messages in order.""" with pubsub_context() as x: - # Create a list to capture received messages - received_messages: list[Any] = [] - all_received = threading.Event() - - # Publish the rest of the values (after the first one used in basic tests) messages_to_send = values[1:] if len(values) > 1 else values + collector = CallbackCollector(len(messages_to_send)) - # Define callback function - def callback(message: Any, topic: Any) -> None: - received_messages.append(message) - if len(received_messages) >= len(messages_to_send): - all_received.set() - - # Subscribe to the topic - x.subscribe(topic, callback) + x.subscribe(topic, collector) for msg in messages_to_send: x.publish(topic, msg) - assert all_received.wait(timeout=1.0), "Timed out waiting for all messages" + collector.wait() - # Verify all messages were received in order + received_messages = [r[0] for r in collector.results] assert len(received_messages) == len(messages_to_send) assert received_messages == messages_to_send diff --git a/dimos/utils/testing/collector.py b/dimos/utils/testing/collector.py new file mode 100644 index 0000000000..bcc3150e73 --- /dev/null +++ b/dimos/utils/testing/collector.py @@ -0,0 +1,50 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Callback collector with Event-based synchronization for async tests.""" + +import threading +from typing import Any + + +class CallbackCollector: + """Callable that collects ``(msg, topic)`` pairs and signals when *n* arrive. + + Designed as a drop-in subscription callback for pubsub tests:: + + collector = CallbackCollector(3) + sub.subscribe(topic, collector) + # ... publish 3 messages ... + collector.wait() + assert len(collector.results) == 3 + """ + + def __init__(self, n: int, timeout: float = 2.0) -> None: + self.results: list[tuple[Any, Any]] = [] + self._done = threading.Event() + self._n = n + self.timeout = timeout + + def __call__(self, msg: Any, topic: Any) -> None: + self.results.append((msg, topic)) + if len(self.results) >= self._n: + self._done.set() + + def wait(self) -> None: + """Block until *n* items have been collected, or *timeout* expires.""" + if not self._done.wait(self.timeout): + raise AssertionError( + f"Timed out after {self.timeout}s waiting for {self._n} messages " + f"(got {len(self.results)})" + ) From cb648f56aec6baa89e8f5d29f2aee6b1ecda75eb Mon Sep 17 00:00:00 2001 From: RD <63036454+ruthwikdasyam@users.noreply.github.com> Date: Thu, 19 Mar 2026 01:30:40 -0700 Subject: [PATCH 21/42] refactor: split control blueprints + added env variables (#1601) * feat: adding arm_ip and can_port to env * feat: using env variables in blueprints * arm_ip env variables * misc: control blueprints cleanup * refactor: hardware factories * fix: pre-commit checks * fix: gripper check + comments * fix: gripper addition * fix: no init needed, blueprint path * CI code cleanup * check trigger commit * fix: unwanted changes * fix: blueprint path * fix: remove duplicates * feat: env var from globalconfig --- dimos/control/blueprints.py | 692 ------------------ dimos/control/blueprints/_hardware.py | 93 +++ dimos/control/blueprints/basic.py | 117 +++ dimos/control/blueprints/dual.py | 104 +++ dimos/control/blueprints/mobile.py | 79 ++ dimos/control/blueprints/teleop.py | 249 +++++++ .../examples/twist_base_keyboard_teleop.py | 2 +- dimos/core/global_config.py | 3 + dimos/robot/all_blueprints.py | 36 +- dimos/teleop/quest/blueprints.py | 2 +- 10 files changed, 665 insertions(+), 712 deletions(-) delete mode 100644 dimos/control/blueprints.py create mode 100644 dimos/control/blueprints/_hardware.py create mode 100644 dimos/control/blueprints/basic.py create mode 100644 dimos/control/blueprints/dual.py create mode 100644 dimos/control/blueprints/mobile.py create mode 100644 dimos/control/blueprints/teleop.py diff --git a/dimos/control/blueprints.py b/dimos/control/blueprints.py deleted file mode 100644 index fff2083322..0000000000 --- a/dimos/control/blueprints.py +++ /dev/null @@ -1,692 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Pre-configured blueprints for the ControlCoordinator. - -This module provides ready-to-use coordinator blueprints for common setups. - -Usage: - # Run via CLI: - dimos run coordinator-mock # Mock 7-DOF arm - dimos run coordinator-xarm7 # XArm7 real hardware - dimos run coordinator-dual-mock # Dual mock arms - - # Or programmatically: - from dimos.control.blueprints import coordinator_mock - coordinator = coordinator_mock.build() - coordinator.loop() -""" - -from __future__ import annotations - -from dimos.control.components import ( - HardwareComponent, - HardwareType, - make_gripper_joints, - make_joints, - make_twist_base_joints, -) -from dimos.control.coordinator import TaskConfig, control_coordinator -from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.teleop.quest.quest_types import Buttons -from dimos.utils.data import LfsPath - -_PIPER_MODEL_PATH = LfsPath("piper_description/mujoco_model/piper_no_gripper_description.xml") -_XARM6_MODEL_PATH = LfsPath("xarm_description/urdf/xarm6/xarm6.urdf") -_XARM7_MODEL_PATH = LfsPath("xarm_description/urdf/xarm7/xarm7.urdf") - - -# Mock 7-DOF arm (for testing) -coordinator_mock = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 7), - adapter_type="mock", - ), - ], - tasks=[ - TaskConfig( - name="traj_arm", - type="trajectory", - joint_names=[f"arm_joint{i + 1}" for i in range(7)], - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - } -) - -# XArm7 real hardware -coordinator_xarm7 = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 7), - adapter_type="xarm", - address="192.168.2.235", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="traj_arm", - type="trajectory", - joint_names=[f"arm_joint{i + 1}" for i in range(7)], - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - } -) - -# XArm6 real hardware -coordinator_xarm6 = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 6), - adapter_type="xarm", - address="192.168.1.210", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="traj_xarm", - type="trajectory", - joint_names=[f"arm_joint{i + 1}" for i in range(6)], - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - } -) - -# Piper arm (6-DOF, CAN bus) -coordinator_piper = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 6), - adapter_type="piper", - address="can0", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="traj_piper", - type="trajectory", - joint_names=[f"arm_joint{i + 1}" for i in range(6)], - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - } -) - - -# Dual mock arms (7-DOF left, 6-DOF right) -coordinator_dual_mock = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="left_arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("left_arm", 7), - adapter_type="mock", - ), - HardwareComponent( - hardware_id="right_arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("right_arm", 6), - adapter_type="mock", - ), - ], - tasks=[ - TaskConfig( - name="traj_left", - type="trajectory", - joint_names=[f"left_arm_joint{i + 1}" for i in range(7)], - priority=10, - ), - TaskConfig( - name="traj_right", - type="trajectory", - joint_names=[f"right_arm_joint{i + 1}" for i in range(6)], - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - } -) - -# Dual XArm (XArm7 left, XArm6 right) -coordinator_dual_xarm = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="left_arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("left_arm", 7), - adapter_type="xarm", - address="192.168.2.235", - auto_enable=True, - ), - HardwareComponent( - hardware_id="right_arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("right_arm", 6), - adapter_type="xarm", - address="192.168.1.210", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="traj_left", - type="trajectory", - joint_names=[f"left_arm_joint{i + 1}" for i in range(7)], - priority=10, - ), - TaskConfig( - name="traj_right", - type="trajectory", - joint_names=[f"right_arm_joint{i + 1}" for i in range(6)], - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - } -) - -# Dual arm (XArm6 + Piper) -coordinator_piper_xarm = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="xarm_arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("xarm_arm", 6), - adapter_type="xarm", - address="192.168.1.210", - auto_enable=True, - ), - HardwareComponent( - hardware_id="piper_arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("piper_arm", 6), - adapter_type="piper", - address="can0", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="traj_xarm", - type="trajectory", - joint_names=[f"xarm_arm_joint{i + 1}" for i in range(6)], - priority=10, - ), - TaskConfig( - name="traj_piper", - type="trajectory", - joint_names=[f"piper_arm_joint{i + 1}" for i in range(6)], - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - } -) - - -# XArm6 teleop - streaming position control -coordinator_teleop_xarm6 = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 6), - adapter_type="xarm", - address="192.168.1.210", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="servo_arm", - type="servo", - joint_names=[f"arm_joint{i + 1}" for i in range(6)], - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("joint_command", JointState): LCMTransport("/teleop/joint_command", JointState), - } -) - -# XArm6 velocity control - streaming velocity for joystick -coordinator_velocity_xarm6 = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 6), - adapter_type="xarm", - address="192.168.1.210", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="velocity_arm", - type="velocity", - joint_names=[f"arm_joint{i + 1}" for i in range(6)], - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("joint_command", JointState): LCMTransport("/joystick/joint_command", JointState), - } -) - -# XArm6 combined (servo + velocity tasks) -coordinator_combined_xarm6 = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 6), - adapter_type="xarm", - address="192.168.1.210", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="servo_arm", - type="servo", - joint_names=[f"arm_joint{i + 1}" for i in range(6)], - priority=10, - ), - TaskConfig( - name="velocity_arm", - type="velocity", - joint_names=[f"arm_joint{i + 1}" for i in range(6)], - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("joint_command", JointState): LCMTransport("/control/joint_command", JointState), - } -) - - -# Mock 6-DOF arm with CartesianIK -coordinator_cartesian_ik_mock = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 6), - adapter_type="mock", - ), - ], - tasks=[ - TaskConfig( - name="cartesian_ik_arm", - type="cartesian_ik", - joint_names=[f"arm_joint{i + 1}" for i in range(6)], - priority=10, - model_path=_PIPER_MODEL_PATH, - ee_joint_id=6, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("cartesian_command", PoseStamped): LCMTransport( - "/coordinator/cartesian_command", PoseStamped - ), - } -) - -# Piper arm with CartesianIK -coordinator_cartesian_ik_piper = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 6), - adapter_type="piper", - address="can0", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="cartesian_ik_arm", - type="cartesian_ik", - joint_names=[f"arm_joint{i + 1}" for i in range(6)], - priority=10, - model_path=_PIPER_MODEL_PATH, - ee_joint_id=6, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("cartesian_command", PoseStamped): LCMTransport( - "/coordinator/cartesian_command", PoseStamped - ), - } -) - - -# Single XArm7 with TeleopIK -coordinator_teleop_xarm7 = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 7), - adapter_type="xarm", - address="192.168.2.235", - auto_enable=True, - gripper_joints=make_gripper_joints("arm"), - ), - ], - tasks=[ - TaskConfig( - name="teleop_xarm", - type="teleop_ik", - joint_names=[f"arm_joint{i + 1}" for i in range(7)], - priority=10, - model_path=_XARM7_MODEL_PATH, - ee_joint_id=7, - hand="right", - gripper_joint=make_gripper_joints("arm")[0], - gripper_open_pos=0.85, # xArm gripper range - gripper_closed_pos=0.0, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("cartesian_command", PoseStamped): LCMTransport( - "/coordinator/cartesian_command", PoseStamped - ), - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - } -) - -# Single Piper with TeleopIK -coordinator_teleop_piper = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 6), - adapter_type="piper", - address="can0", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="teleop_piper", - type="teleop_ik", - joint_names=[f"arm_joint{i + 1}" for i in range(6)], - priority=10, - model_path=_PIPER_MODEL_PATH, - ee_joint_id=6, - hand="left", - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("cartesian_command", PoseStamped): LCMTransport( - "/coordinator/cartesian_command", PoseStamped - ), - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - } -) - -# Dual arm teleop: XArm6 + Piper with TeleopIK -coordinator_teleop_dual = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", - hardware=[ - HardwareComponent( - hardware_id="xarm_arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("xarm_arm", 6), - adapter_type="xarm", - address="192.168.1.210", - auto_enable=True, - ), - HardwareComponent( - hardware_id="piper_arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("piper_arm", 6), - adapter_type="piper", - address="can0", - auto_enable=True, - ), - ], - tasks=[ - TaskConfig( - name="teleop_xarm", - type="teleop_ik", - joint_names=[f"xarm_arm_joint{i + 1}" for i in range(6)], - priority=10, - model_path=_XARM6_MODEL_PATH, - ee_joint_id=6, - hand="left", - ), - TaskConfig( - name="teleop_piper", - type="teleop_ik", - joint_names=[f"piper_arm_joint{i + 1}" for i in range(6)], - priority=10, - model_path=_PIPER_MODEL_PATH, - ee_joint_id=6, - hand="right", - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("cartesian_command", PoseStamped): LCMTransport( - "/coordinator/cartesian_command", PoseStamped - ), - ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), - } -) - - -# Mock holonomic twist base (3-DOF: vx, vy, wz) -_base_joints = make_twist_base_joints("base") -coordinator_mock_twist_base = control_coordinator( - hardware=[ - HardwareComponent( - hardware_id="base", - hardware_type=HardwareType.BASE, - joints=_base_joints, - adapter_type="mock_twist_base", - ), - ], - tasks=[ - TaskConfig( - name="vel_base", - type="velocity", - joint_names=_base_joints, - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("twist_command", Twist): LCMTransport("/cmd_vel", Twist), - } -) - - -# Mock arm (7-DOF) + mock holonomic base (3-DOF) -_mm_base_joints = make_twist_base_joints("base") -coordinator_mobile_manip_mock = control_coordinator( - hardware=[ - HardwareComponent( - hardware_id="arm", - hardware_type=HardwareType.MANIPULATOR, - joints=make_joints("arm", 7), - adapter_type="mock", - ), - HardwareComponent( - hardware_id="base", - hardware_type=HardwareType.BASE, - joints=_mm_base_joints, - adapter_type="mock_twist_base", - ), - ], - tasks=[ - TaskConfig( - name="traj_arm", - type="trajectory", - joint_names=[f"arm_joint{i + 1}" for i in range(7)], - priority=10, - ), - TaskConfig( - name="vel_base", - type="velocity", - joint_names=_mm_base_joints, - priority=10, - ), - ], -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - ("twist_command", Twist): LCMTransport("/cmd_vel", Twist), - } -) - - -coordinator_basic = control_coordinator( - tick_rate=100.0, - publish_joint_state=True, - joint_state_frame_id="coordinator", -).transports( - { - ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), - } -) - - -__all__ = [ - # Raw - "coordinator_basic", - # Cartesian IK - "coordinator_cartesian_ik_mock", - "coordinator_cartesian_ik_piper", - # Streaming control - "coordinator_combined_xarm6", - # Dual arm - "coordinator_dual_mock", - "coordinator_dual_xarm", - # Mobile manipulation - "coordinator_mobile_manip_mock", - # Single arm - "coordinator_mock", - # Twist base - "coordinator_mock_twist_base", - "coordinator_piper", - "coordinator_piper_xarm", - # Teleop IK - "coordinator_teleop_dual", - "coordinator_teleop_piper", - "coordinator_teleop_xarm6", - "coordinator_teleop_xarm7", - "coordinator_velocity_xarm6", - "coordinator_xarm6", - "coordinator_xarm7", -] diff --git a/dimos/control/blueprints/_hardware.py b/dimos/control/blueprints/_hardware.py new file mode 100644 index 0000000000..a36027865a --- /dev/null +++ b/dimos/control/blueprints/_hardware.py @@ -0,0 +1,93 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hardware component factories for coordinator blueprints.""" + +from __future__ import annotations + +from dimos.control.components import ( + HardwareComponent, + HardwareType, + make_gripper_joints, + make_joints, + make_twist_base_joints, +) +from dimos.core.global_config import global_config +from dimos.utils.data import LfsPath + +XARM7_IP = global_config.xarm7_ip +XARM6_IP = global_config.xarm6_ip +CAN_PORT = global_config.can_port + +PIPER_MODEL_PATH = LfsPath("piper_description/mujoco_model/piper_no_gripper_description.xml") +XARM6_MODEL_PATH = LfsPath("xarm_description/urdf/xarm6/xarm6.urdf") +XARM7_MODEL_PATH = LfsPath("xarm_description/urdf/xarm7/xarm7.urdf") + + +def mock_arm(hw_id: str = "arm", n_joints: int = 7) -> HardwareComponent: + """Mock manipulator (no real hardware).""" + return HardwareComponent( + hardware_id=hw_id, + hardware_type=HardwareType.MANIPULATOR, + joints=make_joints(hw_id, n_joints), + adapter_type="mock", + ) + + +def xarm7(hw_id: str = "arm", *, gripper: bool = False) -> HardwareComponent: + """XArm7 real hardware (7-DOF).""" + return HardwareComponent( + hardware_id=hw_id, + hardware_type=HardwareType.MANIPULATOR, + joints=make_joints(hw_id, 7), + adapter_type="xarm", + address=XARM7_IP, + auto_enable=True, + gripper_joints=make_gripper_joints(hw_id) if gripper else [], + ) + + +def xarm6(hw_id: str = "arm", *, gripper: bool = False) -> HardwareComponent: + """XArm6 real hardware (6-DOF).""" + return HardwareComponent( + hardware_id=hw_id, + hardware_type=HardwareType.MANIPULATOR, + joints=make_joints(hw_id, 6), + adapter_type="xarm", + address=XARM6_IP, + auto_enable=True, + gripper_joints=make_gripper_joints(hw_id) if gripper else [], + ) + + +def piper(hw_id: str = "arm") -> HardwareComponent: + """Piper arm (6-DOF, CAN bus).""" + return HardwareComponent( + hardware_id=hw_id, + hardware_type=HardwareType.MANIPULATOR, + joints=make_joints(hw_id, 6), + adapter_type="piper", + address=CAN_PORT, + auto_enable=True, + ) + + +def mock_twist_base(hw_id: str = "base") -> HardwareComponent: + """Mock holonomic twist base (3-DOF: vx, vy, wz).""" + return HardwareComponent( + hardware_id=hw_id, + hardware_type=HardwareType.BASE, + joints=make_twist_base_joints(hw_id), + adapter_type="mock_twist_base", + ) diff --git a/dimos/control/blueprints/basic.py b/dimos/control/blueprints/basic.py new file mode 100644 index 0000000000..7ad441ed70 --- /dev/null +++ b/dimos/control/blueprints/basic.py @@ -0,0 +1,117 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Single-arm coordinator blueprints with trajectory control. + +Usage: + dimos run coordinator-mock # Mock 7-DOF arm + dimos run coordinator-xarm7 # XArm7 real hardware + dimos run coordinator-xarm6 # XArm6 real hardware + dimos run coordinator-piper # Piper arm (CAN bus) +""" + +from __future__ import annotations + +from dimos.control.blueprints._hardware import mock_arm, piper, xarm6, xarm7 +from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.core.transport import LCMTransport +from dimos.msgs.sensor_msgs.JointState import JointState + +# Minimal blueprint (no hardware, no tasks) +coordinator_basic = control_coordinator( + tick_rate=100.0, + publish_joint_state=True, + joint_state_frame_id="coordinator", +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +# Mock 7-DOF arm (for testing) +coordinator_mock = control_coordinator( + hardware=[mock_arm()], + tasks=[ + TaskConfig( + name="traj_arm", + type="trajectory", + joint_names=[f"arm_joint{i + 1}" for i in range(7)], + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +# XArm7 real hardware +coordinator_xarm7 = control_coordinator( + hardware=[xarm7()], + tasks=[ + TaskConfig( + name="traj_arm", + type="trajectory", + joint_names=[f"arm_joint{i + 1}" for i in range(7)], + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +# XArm6 real hardware +coordinator_xarm6 = control_coordinator( + hardware=[xarm6()], + tasks=[ + TaskConfig( + name="traj_xarm", + type="trajectory", + joint_names=[f"arm_joint{i + 1}" for i in range(6)], + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +# Piper arm (6-DOF, CAN bus) +coordinator_piper = control_coordinator( + hardware=[piper()], + tasks=[ + TaskConfig( + name="traj_piper", + type="trajectory", + joint_names=[f"arm_joint{i + 1}" for i in range(6)], + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + + +__all__ = [ + "coordinator_basic", + "coordinator_mock", + "coordinator_piper", + "coordinator_xarm6", + "coordinator_xarm7", +] diff --git a/dimos/control/blueprints/dual.py b/dimos/control/blueprints/dual.py new file mode 100644 index 0000000000..8482316ba5 --- /dev/null +++ b/dimos/control/blueprints/dual.py @@ -0,0 +1,104 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dual-arm coordinator blueprints with trajectory control. + +Usage: + dimos run coordinator-dual-mock # Mock 7+6 DOF arms + dimos run coordinator-dual-xarm # XArm7 left + XArm6 right + dimos run coordinator-piper-xarm # XArm6 + Piper +""" + +from __future__ import annotations + +from dimos.control.blueprints._hardware import mock_arm, piper, xarm6, xarm7 +from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.core.transport import LCMTransport +from dimos.msgs.sensor_msgs.JointState import JointState + +# Dual mock arms (7-DOF left, 6-DOF right) +coordinator_dual_mock = control_coordinator( + hardware=[mock_arm("left_arm", 7), mock_arm("right_arm", 6)], + tasks=[ + TaskConfig( + name="traj_left", + type="trajectory", + joint_names=[f"left_arm_joint{i + 1}" for i in range(7)], + priority=10, + ), + TaskConfig( + name="traj_right", + type="trajectory", + joint_names=[f"right_arm_joint{i + 1}" for i in range(6)], + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +# Dual XArm (XArm7 left, XArm6 right) +coordinator_dual_xarm = control_coordinator( + hardware=[xarm7("left_arm"), xarm6("right_arm")], + tasks=[ + TaskConfig( + name="traj_left", + type="trajectory", + joint_names=[f"left_arm_joint{i + 1}" for i in range(7)], + priority=10, + ), + TaskConfig( + name="traj_right", + type="trajectory", + joint_names=[f"right_arm_joint{i + 1}" for i in range(6)], + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +# Dual arm (XArm6 + Piper) +coordinator_piper_xarm = control_coordinator( + hardware=[xarm6("xarm_arm"), piper("piper_arm")], + tasks=[ + TaskConfig( + name="traj_xarm", + type="trajectory", + joint_names=[f"xarm_arm_joint{i + 1}" for i in range(6)], + priority=10, + ), + TaskConfig( + name="traj_piper", + type="trajectory", + joint_names=[f"piper_arm_joint{i + 1}" for i in range(6)], + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + + +__all__ = [ + "coordinator_dual_mock", + "coordinator_dual_xarm", + "coordinator_piper_xarm", +] diff --git a/dimos/control/blueprints/mobile.py b/dimos/control/blueprints/mobile.py new file mode 100644 index 0000000000..4ed3410b8f --- /dev/null +++ b/dimos/control/blueprints/mobile.py @@ -0,0 +1,79 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mobile manipulation coordinator blueprints. + +Usage: + dimos run coordinator-mock-twist-base # Mock holonomic base + dimos run coordinator-mobile-manip-mock # Mock arm + base +""" + +from __future__ import annotations + +from dimos.control.blueprints._hardware import mock_arm, mock_twist_base +from dimos.control.components import make_twist_base_joints +from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.core.transport import LCMTransport +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.JointState import JointState + +_base_joints = make_twist_base_joints("base") + +# Mock holonomic twist base (3-DOF: vx, vy, wz) +coordinator_mock_twist_base = control_coordinator( + hardware=[mock_twist_base()], + tasks=[ + TaskConfig( + name="vel_base", + type="velocity", + joint_names=_base_joints, + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("twist_command", Twist): LCMTransport("/cmd_vel", Twist), + } +) + +# Mock arm (7-DOF) + mock holonomic base (3-DOF) +coordinator_mobile_manip_mock = control_coordinator( + hardware=[mock_arm(), mock_twist_base()], + tasks=[ + TaskConfig( + name="traj_arm", + type="trajectory", + joint_names=[f"arm_joint{i + 1}" for i in range(7)], + priority=10, + ), + TaskConfig( + name="vel_base", + type="velocity", + joint_names=_base_joints, + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("twist_command", Twist): LCMTransport("/cmd_vel", Twist), + } +) + + +__all__ = [ + "coordinator_mobile_manip_mock", + "coordinator_mock_twist_base", +] diff --git a/dimos/control/blueprints/teleop.py b/dimos/control/blueprints/teleop.py new file mode 100644 index 0000000000..2e922bbcbf --- /dev/null +++ b/dimos/control/blueprints/teleop.py @@ -0,0 +1,249 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Advanced control coordinator blueprints: servo, velocity, cartesian IK, and teleop IK. + +Usage: + dimos run coordinator-teleop-xarm6 # Servo streaming (XArm6) + dimos run coordinator-velocity-xarm6 # Velocity streaming (XArm6) + dimos run coordinator-combined-xarm6 # Servo + velocity (XArm6) + dimos run coordinator-cartesian-ik-mock # Cartesian IK (mock) + dimos run coordinator-cartesian-ik-piper # Cartesian IK (Piper) + dimos run coordinator-teleop-xarm7 # TeleopIK (XArm7) + dimos run coordinator-teleop-piper # TeleopIK (Piper) + dimos run coordinator-teleop-dual # TeleopIK dual arm +""" + +from __future__ import annotations + +from dimos.control.blueprints._hardware import ( + PIPER_MODEL_PATH, + XARM6_MODEL_PATH, + XARM7_MODEL_PATH, + mock_arm, + piper, + xarm6, + xarm7, +) +from dimos.control.components import make_gripper_joints +from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.core.transport import LCMTransport +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.teleop.quest.quest_types import Buttons + +# XArm6 teleop - streaming position control +coordinator_teleop_xarm6 = control_coordinator( + hardware=[xarm6()], + tasks=[ + TaskConfig( + name="servo_arm", + type="servo", + joint_names=[f"arm_joint{i + 1}" for i in range(6)], + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("joint_command", JointState): LCMTransport("/teleop/joint_command", JointState), + } +) + +# XArm6 velocity control - streaming velocity for joystick +coordinator_velocity_xarm6 = control_coordinator( + hardware=[xarm6()], + tasks=[ + TaskConfig( + name="velocity_arm", + type="velocity", + joint_names=[f"arm_joint{i + 1}" for i in range(6)], + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("joint_command", JointState): LCMTransport("/joystick/joint_command", JointState), + } +) + +# XArm6 combined (servo + velocity tasks) +coordinator_combined_xarm6 = control_coordinator( + hardware=[xarm6()], + tasks=[ + TaskConfig( + name="servo_arm", + type="servo", + joint_names=[f"arm_joint{i + 1}" for i in range(6)], + priority=10, + ), + TaskConfig( + name="velocity_arm", + type="velocity", + joint_names=[f"arm_joint{i + 1}" for i in range(6)], + priority=10, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("joint_command", JointState): LCMTransport("/control/joint_command", JointState), + } +) + + +# Mock 6-DOF arm with CartesianIK +coordinator_cartesian_ik_mock = control_coordinator( + hardware=[mock_arm("arm", 6)], + tasks=[ + TaskConfig( + name="cartesian_ik_arm", + type="cartesian_ik", + joint_names=[f"arm_joint{i + 1}" for i in range(6)], + priority=10, + model_path=PIPER_MODEL_PATH, + ee_joint_id=6, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("cartesian_command", PoseStamped): LCMTransport( + "/coordinator/cartesian_command", PoseStamped + ), + } +) + +# Piper arm with CartesianIK +coordinator_cartesian_ik_piper = control_coordinator( + hardware=[piper()], + tasks=[ + TaskConfig( + name="cartesian_ik_arm", + type="cartesian_ik", + joint_names=[f"arm_joint{i + 1}" for i in range(6)], + priority=10, + model_path=PIPER_MODEL_PATH, + ee_joint_id=6, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("cartesian_command", PoseStamped): LCMTransport( + "/coordinator/cartesian_command", PoseStamped + ), + } +) + + +# Single XArm7 with TeleopIK +coordinator_teleop_xarm7 = control_coordinator( + hardware=[xarm7(gripper=True)], + tasks=[ + TaskConfig( + name="teleop_xarm", + type="teleop_ik", + joint_names=[f"arm_joint{i + 1}" for i in range(7)], + priority=10, + model_path=XARM7_MODEL_PATH, + ee_joint_id=7, + hand="right", + gripper_joint=make_gripper_joints("arm")[0], + gripper_open_pos=0.85, + gripper_closed_pos=0.0, + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("cartesian_command", PoseStamped): LCMTransport( + "/coordinator/cartesian_command", PoseStamped + ), + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + } +) + +# Single Piper with TeleopIK +coordinator_teleop_piper = control_coordinator( + hardware=[piper()], + tasks=[ + TaskConfig( + name="teleop_piper", + type="teleop_ik", + joint_names=[f"arm_joint{i + 1}" for i in range(6)], + priority=10, + model_path=PIPER_MODEL_PATH, + ee_joint_id=6, + hand="left", + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("cartesian_command", PoseStamped): LCMTransport( + "/coordinator/cartesian_command", PoseStamped + ), + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + } +) + +# Dual arm teleop: XArm6 + Piper with TeleopIK +coordinator_teleop_dual = control_coordinator( + hardware=[xarm6("xarm_arm"), piper("piper_arm")], + tasks=[ + TaskConfig( + name="teleop_xarm", + type="teleop_ik", + joint_names=[f"xarm_arm_joint{i + 1}" for i in range(6)], + priority=10, + model_path=XARM6_MODEL_PATH, + ee_joint_id=6, + hand="left", + ), + TaskConfig( + name="teleop_piper", + type="teleop_ik", + joint_names=[f"piper_arm_joint{i + 1}" for i in range(6)], + priority=10, + model_path=PIPER_MODEL_PATH, + ee_joint_id=6, + hand="right", + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("cartesian_command", PoseStamped): LCMTransport( + "/coordinator/cartesian_command", PoseStamped + ), + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + } +) + + +__all__ = [ + # Cartesian IK + "coordinator_cartesian_ik_mock", + "coordinator_cartesian_ik_piper", + "coordinator_combined_xarm6", + "coordinator_teleop_dual", + "coordinator_teleop_piper", + # Servo / Velocity + "coordinator_teleop_xarm6", + # TeleopIK + "coordinator_teleop_xarm7", + "coordinator_velocity_xarm6", +] diff --git a/dimos/control/examples/twist_base_keyboard_teleop.py b/dimos/control/examples/twist_base_keyboard_teleop.py index 2d7651145a..610f8679e4 100644 --- a/dimos/control/examples/twist_base_keyboard_teleop.py +++ b/dimos/control/examples/twist_base_keyboard_teleop.py @@ -33,7 +33,7 @@ from __future__ import annotations -from dimos.control.blueprints import coordinator_mock_twist_base +from dimos.control.blueprints.mobile import coordinator_mock_twist_base from dimos.robot.unitree.keyboard_teleop import keyboard_teleop diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index 0b070dabd9..42a7fa552a 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -29,6 +29,9 @@ def _get_all_numbers(s: str) -> list[float]: class GlobalConfig(BaseSettings): robot_ip: str | None = None robot_ips: str | None = None + xarm7_ip: str | None = None + xarm6_ip: str | None = None + can_port: str = "can0" simulation: bool = False replay: bool = False replay_dir: str = "go2_sf_office" diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 0d4225e463..5fd61891c2 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -16,24 +16,24 @@ # Run `pytest dimos/robot/test_all_blueprints_generation.py` to regenerate. all_blueprints = { - "coordinator-basic": "dimos.control.blueprints:coordinator_basic", - "coordinator-cartesian-ik-mock": "dimos.control.blueprints:coordinator_cartesian_ik_mock", - "coordinator-cartesian-ik-piper": "dimos.control.blueprints:coordinator_cartesian_ik_piper", - "coordinator-combined-xarm6": "dimos.control.blueprints:coordinator_combined_xarm6", - "coordinator-dual-mock": "dimos.control.blueprints:coordinator_dual_mock", - "coordinator-dual-xarm": "dimos.control.blueprints:coordinator_dual_xarm", - "coordinator-mobile-manip-mock": "dimos.control.blueprints:coordinator_mobile_manip_mock", - "coordinator-mock": "dimos.control.blueprints:coordinator_mock", - "coordinator-mock-twist-base": "dimos.control.blueprints:coordinator_mock_twist_base", - "coordinator-piper": "dimos.control.blueprints:coordinator_piper", - "coordinator-piper-xarm": "dimos.control.blueprints:coordinator_piper_xarm", - "coordinator-teleop-dual": "dimos.control.blueprints:coordinator_teleop_dual", - "coordinator-teleop-piper": "dimos.control.blueprints:coordinator_teleop_piper", - "coordinator-teleop-xarm6": "dimos.control.blueprints:coordinator_teleop_xarm6", - "coordinator-teleop-xarm7": "dimos.control.blueprints:coordinator_teleop_xarm7", - "coordinator-velocity-xarm6": "dimos.control.blueprints:coordinator_velocity_xarm6", - "coordinator-xarm6": "dimos.control.blueprints:coordinator_xarm6", - "coordinator-xarm7": "dimos.control.blueprints:coordinator_xarm7", + "coordinator-basic": "dimos.control.blueprints.basic:coordinator_basic", + "coordinator-cartesian-ik-mock": "dimos.control.blueprints.teleop:coordinator_cartesian_ik_mock", + "coordinator-cartesian-ik-piper": "dimos.control.blueprints.teleop:coordinator_cartesian_ik_piper", + "coordinator-combined-xarm6": "dimos.control.blueprints.teleop:coordinator_combined_xarm6", + "coordinator-dual-mock": "dimos.control.blueprints.dual:coordinator_dual_mock", + "coordinator-dual-xarm": "dimos.control.blueprints.dual:coordinator_dual_xarm", + "coordinator-mobile-manip-mock": "dimos.control.blueprints.mobile:coordinator_mobile_manip_mock", + "coordinator-mock": "dimos.control.blueprints.basic:coordinator_mock", + "coordinator-mock-twist-base": "dimos.control.blueprints.mobile:coordinator_mock_twist_base", + "coordinator-piper": "dimos.control.blueprints.basic:coordinator_piper", + "coordinator-piper-xarm": "dimos.control.blueprints.dual:coordinator_piper_xarm", + "coordinator-teleop-dual": "dimos.control.blueprints.teleop:coordinator_teleop_dual", + "coordinator-teleop-piper": "dimos.control.blueprints.teleop:coordinator_teleop_piper", + "coordinator-teleop-xarm6": "dimos.control.blueprints.teleop:coordinator_teleop_xarm6", + "coordinator-teleop-xarm7": "dimos.control.blueprints.teleop:coordinator_teleop_xarm7", + "coordinator-velocity-xarm6": "dimos.control.blueprints.teleop:coordinator_velocity_xarm6", + "coordinator-xarm6": "dimos.control.blueprints.basic:coordinator_xarm6", + "coordinator-xarm7": "dimos.control.blueprints.basic:coordinator_xarm7", "demo-agent": "dimos.agents.demo_agent:demo_agent", "demo-agent-camera": "dimos.agents.demo_agent:demo_agent_camera", "demo-camera": "dimos.hardware.sensors.camera.module:demo_camera", diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index da07a1bdd4..71e16c2da8 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -15,7 +15,7 @@ """Teleop blueprints for testing and deployment.""" -from dimos.control.blueprints import ( +from dimos.control.blueprints.teleop import ( coordinator_teleop_dual, coordinator_teleop_piper, coordinator_teleop_xarm7, From cdac06e2f6b725f3ff457e79e5e1ddae8a2d5685 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 19 Mar 2026 15:15:14 -0700 Subject: [PATCH 22/42] - (#1610) --- flake.nix | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/flake.nix b/flake.nix index 68dbf0ee8c..c22b1f7791 100644 --- a/flake.nix +++ b/flake.nix @@ -160,10 +160,18 @@ nativeBuildInputs = (old.nativeBuildInputs or []) ++ [ pkgs.pkg-config pkgs.python312 ]; # 1. fix pkg-config on darwin env.PKG_CONFIG_PATH = packageConfPackagesString; - # 2. Fix fsync on darwin - patches = [ - (pkgs.writeText "lcm-darwin-fsync.patch" "--- ./lcm-logger/lcm_logger.c 2025-11-14 09:46:01.000000000 -0600\n+++ ./lcm-logger/lcm_logger.c 2025-11-14 09:47:05.000000000 -0600\n@@ -428,9 +428,13 @@\n if (needs_flushed) {\n fflush(logger->log->f);\n #ifndef WIN32\n+#ifdef __APPLE__\n+ fsync(fileno(logger->log->f));\n+#else\n // Perform a full fsync operation after flush\n fdatasync(fileno(logger->log->f));\n #endif\n+#endif\n logger->last_fflush_time = log_event->timestamp;\n }\n") - ]; + # Remove upstream patches (the darwin-fsync patch causes "out of memory" in patch utility) + patches = []; + # 2. Fix fsync on darwin (use substituteInPlace to avoid patch utility issues) + postPatch = (old.postPatch or "") + '' + substituteInPlace lcm-logger/lcm_logger.c \ + --replace-fail 'fdatasync(fileno(logger->log->f));' \ + '#ifdef __APPLE__ + fsync(fileno(logger->log->f)); + #else + fdatasync(fileno(logger->log->f)); + #endif' + ''; } ); } From c8a7b7dca736cea78cb2fd4e60f683d8d3a74f92 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Thu, 19 Mar 2026 22:29:58 -0700 Subject: [PATCH 23/42] fix(cli): fix `dimos --help` (both bad imports and speed) (#1571) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(cli): speed up `dimos --help` by extracting lightweight type aliases Move NavigationStrategy and VlModelName type aliases into dimos/core/types.py so that global_config.py no longer pulls in matplotlib/scipy (via path_map.py) or torch/langchain (via create.py) at import time. Original modules re-export from the new file so existing imports continue to work. `dimos --help` drops from ~3-4s to ~1.9s. * fix: move type aliases to their respective packages NavigationStrategy → dimos/mapping/occupancy/types.py VlModelName → dimos/models/vl/types.py Remove dimos/core/types.py * test: add CLI startup speed regression test Guards against heavy imports (matplotlib, torch, scipy) leaking into the CLI entrypoint via GlobalConfig. Fails if dimos --help takes >8s. * CI code cleanup * fix: exclude .venv and other non-source dirs from doclinks file index build_file_index now skips paths rooted in .venv, node_modules, __pycache__, or .git. Fixes test_excludes_venv failure when .venv is a symlink (not matched by gitignore trailing-slash patterns). * Revert "fix: exclude .venv and other non-source dirs from doclinks file index" This reverts commit 61f8588a58f121cbf72ba1ba7d22d32482a1b0de. --------- Co-authored-by: jeff-hykin <17692058+jeff-hykin@users.noreply.github.com> --- dimos/core/global_config.py | 2 +- dimos/mapping/occupancy/path_map.py | 5 +-- dimos/mapping/occupancy/types.py | 17 ++++++++ dimos/models/vl/create.py | 5 +-- dimos/models/vl/types.py | 3 ++ dimos/robot/cli/test_cli_startup.py | 63 +++++++++++++++++++++++++++++ 6 files changed, 87 insertions(+), 8 deletions(-) create mode 100644 dimos/mapping/occupancy/types.py create mode 100644 dimos/models/vl/types.py create mode 100644 dimos/robot/cli/test_cli_startup.py diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index 42a7fa552a..90461932a2 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -17,7 +17,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict -from dimos.models.vl.create import VlModelName +from dimos.models.vl.types import VlModelName ViewerBackend: TypeAlias = Literal["rerun", "rerun-web", "rerun-connect", "foxglove", "none"] diff --git a/dimos/mapping/occupancy/path_map.py b/dimos/mapping/occupancy/path_map.py index 8920c6e30b..69a1f93738 100644 --- a/dimos/mapping/occupancy/path_map.py +++ b/dimos/mapping/occupancy/path_map.py @@ -12,15 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, TypeAlias - from dimos.mapping.occupancy.gradient import GradientStrategy, gradient, voronoi_gradient from dimos.mapping.occupancy.inflation import simple_inflate from dimos.mapping.occupancy.operations import overlay_occupied, smooth_occupied +from dimos.mapping.occupancy.types import NavigationStrategy from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid -NavigationStrategy: TypeAlias = Literal["simple", "mixed"] - def make_navigation_map( occupancy_grid: OccupancyGrid, diff --git a/dimos/mapping/occupancy/types.py b/dimos/mapping/occupancy/types.py new file mode 100644 index 0000000000..87f2084698 --- /dev/null +++ b/dimos/mapping/occupancy/types.py @@ -0,0 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal, TypeAlias + +NavigationStrategy: TypeAlias = Literal["simple", "mixed"] diff --git a/dimos/models/vl/create.py b/dimos/models/vl/create.py index 6c778d4104..7fe5a0dcb2 100644 --- a/dimos/models/vl/create.py +++ b/dimos/models/vl/create.py @@ -1,9 +1,8 @@ -from typing import Any, Literal +from typing import Any +from dimos.models.vl.types import VlModelName from dimos.models.vl.base import VlModel -VlModelName = Literal["qwen", "moondream"] - def create(name: VlModelName) -> VlModel[Any]: # This uses inline imports to only import what's needed. diff --git a/dimos/models/vl/types.py b/dimos/models/vl/types.py new file mode 100644 index 0000000000..ac8b0f024d --- /dev/null +++ b/dimos/models/vl/types.py @@ -0,0 +1,3 @@ +from typing import Literal + +VlModelName = Literal["qwen", "moondream"] diff --git a/dimos/robot/cli/test_cli_startup.py b/dimos/robot/cli/test_cli_startup.py new file mode 100644 index 0000000000..aa9886c24e --- /dev/null +++ b/dimos/robot/cli/test_cli_startup.py @@ -0,0 +1,63 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Guard against import-time regressions in the CLI entrypoint. + +`dimos --help` should never pull in heavy ML/viz libraries. If it does, +startup time balloons from <2s to >5s, which is a terrible UX. +""" + +import subprocess +import sys +import time + +# CI runners are slower — give generous headroom but still catch gross regressions. +HELP_TIMEOUT_SECONDS = 8 + + +def test_help_does_not_import_heavy_deps() -> None: + """GlobalConfig import must not drag in matplotlib, torch, or scipy.""" + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import sys; " + "from dimos.core.global_config import GlobalConfig; " + "bad = [m for m in ('matplotlib', 'torch', 'scipy') if m in sys.modules]; " + "assert not bad, f'Heavy deps imported: {bad}'" + ), + ], + capture_output=True, + text=True, + timeout=30, + ) + assert result.returncode == 0, f"Heavy deps leaked into GlobalConfig import:\n{result.stderr}" + + +def test_help_startup_time() -> None: + """`dimos --help` must finish in under {HELP_TIMEOUT_SECONDS}s.""" + start = time.monotonic() + result = subprocess.run( + [sys.executable, "-m", "dimos.robot.cli.dimos", "--help"], + capture_output=True, + text=True, + timeout=HELP_TIMEOUT_SECONDS + 5, # hard kill safety margin + ) + elapsed = time.monotonic() - start + assert result.returncode == 0, f"dimos --help failed:\n{result.stderr}" + assert elapsed < HELP_TIMEOUT_SECONDS, ( + f"dimos --help took {elapsed:.1f}s (limit: {HELP_TIMEOUT_SECONDS}s). " + f"Check for heavy imports in the CLI entrypoint or GlobalConfig." + ) From dc331b7993f7b20d4dc48e4d66ee84ad692cfd40 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Fri, 20 Mar 2026 10:53:21 +0200 Subject: [PATCH 24/42] chore(blueprints): remove aliases (#1606) --- dimos/agents/agent.py | 5 - dimos/agents/demo_agent.py | 4 +- dimos/agents/mcp/mcp_client.py | 5 - dimos/agents/skills/demo_calculator_skill.py | 5 - dimos/agents/skills/demo_google_maps_skill.py | 12 +- dimos/agents/skills/demo_gps_nav.py | 12 +- dimos/agents/skills/demo_robot.py | 6 - dimos/agents/skills/demo_skill.py | 8 +- .../skills/google_maps_skill_container.py | 5 - dimos/agents/skills/gps_nav_skill.py | 6 - dimos/agents/skills/navigation.py | 5 - dimos/agents/skills/osm.py | 5 - dimos/agents/skills/person_follow.py | 5 - dimos/agents/skills/speak_skill.py | 5 - dimos/agents/vlm_agent.py | 5 - dimos/agents/vlm_stream_tester.py | 5 - dimos/agents/web_human_input.py | 5 - dimos/control/README.md | 4 +- dimos/control/blueprints/basic.py | 12 +- dimos/control/blueprints/dual.py | 8 +- dimos/control/blueprints/mobile.py | 6 +- dimos/control/blueprints/teleop.py | 18 +-- dimos/control/coordinator.py | 13 -- .../examples/twist_base_keyboard_teleop.py | 4 +- dimos/core/test_blueprints.py | 23 ++- dimos/hardware/sensors/camera/module.py | 10 +- .../sensors/camera/realsense/camera.py | 5 - dimos/hardware/sensors/camera/zed/camera.py | 3 - dimos/hardware/sensors/camera/zed/compat.py | 25 ++-- .../lidar/fastlio2/fastlio_blueprints.py | 8 +- .../hardware/sensors/lidar/fastlio2/module.py | 8 - .../sensors/lidar/livox/livox_blueprints.py | 4 +- dimos/hardware/sensors/lidar/livox/module.py | 8 - dimos/manipulation/blueprints.py | 28 ++-- .../cartesian_motion_controller.py | 4 - .../joint_trajectory_controller.py | 4 - dimos/manipulation/grasping/demo_grasping.py | 20 +-- .../manipulation/grasping/graspgen_module.py | 3 - dimos/manipulation/grasping/grasping.py | 4 - dimos/manipulation/manipulation_module.py | 4 - dimos/manipulation/pick_and_place_module.py | 4 - dimos/mapping/costmapper.py | 3 - dimos/mapping/osm/demo_osm.py | 12 +- dimos/mapping/voxels.py | 3 - .../wavefront_frontier_goal_selector.py | 5 - dimos/navigation/replanning_a_star/module.py | 5 - dimos/navigation/rosnav.py | 6 - .../demo_object_scene_registration.py | 20 +-- dimos/perception/detection/module3D.py | 5 - dimos/perception/detection/moduleDB.py | 5 - dimos/perception/detection/person_tracker.py | 5 - .../temporal_memory/temporal_memory.py | 7 +- .../temporal_memory/temporal_utils/helpers.py | 2 +- dimos/perception/object_scene_registration.py | 5 - dimos/perception/object_tracker.py | 5 - dimos/perception/spatial_perception.py | 5 - dimos/robot/all_blueprints.py | 58 ++++--- .../drone/blueprints/agentic/drone_agentic.py | 8 +- .../drone/blueprints/basic/drone_basic.py | 12 +- dimos/robot/foxglove_bridge.py | 6 - dimos/robot/manipulators/piper/blueprints.py | 12 +- dimos/robot/manipulators/xarm/blueprints.py | 18 +-- dimos/robot/test_all_blueprints.py | 1 + dimos/robot/test_all_blueprints_generation.py | 141 ++++++++++++++---- .../g1/blueprints/agentic/_agentic_skills.py | 16 +- .../g1/blueprints/agentic/unitree_g1_full.py | 4 +- .../g1/blueprints/basic/unitree_g1_basic.py | 8 +- .../blueprints/basic/unitree_g1_basic_sim.py | 8 +- .../blueprints/basic/unitree_g1_joystick.py | 4 +- .../perceptive/_perception_and_memory.py | 8 +- .../perceptive/unitree_g1_detection.py | 12 +- .../blueprints/perceptive/unitree_g1_shm.py | 4 +- .../primitive/uintree_g1_primitive_no_nav.py | 30 ++-- dimos/robot/unitree/g1/connection.py | 6 - dimos/robot/unitree/g1/sim.py | 6 - dimos/robot/unitree/g1/skill_container.py | 4 - .../go2/blueprints/agentic/_common_agentic.py | 20 +-- .../blueprints/agentic/unitree_go2_agentic.py | 4 +- .../unitree_go2_agentic_huggingface.py | 4 +- .../agentic/unitree_go2_agentic_mcp.py | 4 +- .../agentic/unitree_go2_agentic_ollama.py | 4 +- .../agentic/unitree_go2_temporal_memory.py | 4 +- .../go2/blueprints/basic/unitree_go2_basic.py | 17 ++- .../go2/blueprints/basic/unitree_go2_fleet.py | 8 +- .../go2/blueprints/smart/unitree_go2.py | 16 +- .../blueprints/smart/unitree_go2_detection.py | 4 +- .../blueprints/smart/unitree_go2_spatial.py | 4 +- .../smart/unitree_go2_vlm_stream_test.py | 8 +- dimos/robot/unitree/go2/connection.py | 6 - dimos/robot/unitree/go2/fleet_connection.py | 6 - dimos/robot/unitree/keyboard_teleop.py | 5 - dimos/robot/unitree/type/map.py | 6 - .../robot/unitree/unitree_skill_container.py | 5 - dimos/simulation/manipulators/sim_module.py | 9 -- dimos/simulation/sim_blueprints.py | 5 +- .../teleop/keyboard/keyboard_teleop_module.py | 3 - dimos/teleop/phone/blueprints.py | 8 +- dimos/teleop/phone/phone_extensions.py | 8 - dimos/teleop/phone/phone_teleop_module.py | 9 -- dimos/teleop/quest/blueprints.py | 14 +- dimos/teleop/quest/quest_extensions.py | 13 -- dimos/teleop/quest/quest_teleop_module.py | 11 -- dimos/visualization/rerun/bridge.py | 4 - .../web/websocket_vis/websocket_vis_module.py | 5 - .../manipulation/adding_a_custom_arm.md | 4 +- pyproject.toml | 2 +- 106 files changed, 409 insertions(+), 603 deletions(-) diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index 672d30c3de..cce8148e2e 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -259,8 +259,3 @@ def _append_image_to_history(agent: Agent, skill: SkillInfo, uuid_: str, result: ] ) ) - - -agent = Agent.blueprint - -__all__ = ["Agent", "AgentSpec", "agent"] diff --git a/dimos/agents/demo_agent.py b/dimos/agents/demo_agent.py index b839b0809c..29396f3cfa 100644 --- a/dimos/agents/demo_agent.py +++ b/dimos/agents/demo_agent.py @@ -14,7 +14,7 @@ from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect -from dimos.hardware.sensors.camera.module import camera_module +from dimos.hardware.sensors.camera.module import CameraModule from dimos.hardware.sensors.camera.webcam import Webcam from dimos.hardware.sensors.camera.zed import compat as zed @@ -31,7 +31,7 @@ def _create_webcam() -> Webcam: demo_agent_camera = autoconnect( Agent.blueprint(), - camera_module( + CameraModule.blueprint( hardware=_create_webcam, ), ) diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index b32d195de8..e0200a6323 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -304,8 +304,3 @@ def _append_image_to_history( ] ) ) - - -mcp_client = McpClient.blueprint - -__all__ = ["McpClient", "McpClientConfig", "mcp_client"] diff --git a/dimos/agents/skills/demo_calculator_skill.py b/dimos/agents/skills/demo_calculator_skill.py index 61d66e301a..6c0605bb7c 100644 --- a/dimos/agents/skills/demo_calculator_skill.py +++ b/dimos/agents/skills/demo_calculator_skill.py @@ -36,8 +36,3 @@ def sum_numbers(self, n1: int, n2: int, *args: int, **kwargs: int) -> str: """ return f"{int(n1) + int(n2)}" - - -demo_calculator_skill = DemoCalculatorSkill.blueprint - -__all__ = ["DemoCalculatorSkill", "demo_calculator_skill"] diff --git a/dimos/agents/skills/demo_google_maps_skill.py b/dimos/agents/skills/demo_google_maps_skill.py index 13f2ebc19b..616bbac90a 100644 --- a/dimos/agents/skills/demo_google_maps_skill.py +++ b/dimos/agents/skills/demo_google_maps_skill.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.agent import agent -from dimos.agents.skills.demo_robot import demo_robot -from dimos.agents.skills.google_maps_skill_container import google_maps_skill +from dimos.agents.agent import Agent +from dimos.agents.skills.demo_robot import DemoRobot +from dimos.agents.skills.google_maps_skill_container import GoogleMapsSkillContainer from dimos.core.blueprints import autoconnect demo_google_maps_skill = autoconnect( - demo_robot(), - google_maps_skill(), - agent(), + DemoRobot.blueprint(), + GoogleMapsSkillContainer.blueprint(), + Agent.blueprint(), ) diff --git a/dimos/agents/skills/demo_gps_nav.py b/dimos/agents/skills/demo_gps_nav.py index 7a6abd32dd..4810fc3883 100644 --- a/dimos/agents/skills/demo_gps_nav.py +++ b/dimos/agents/skills/demo_gps_nav.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.agent import agent -from dimos.agents.skills.demo_robot import demo_robot -from dimos.agents.skills.gps_nav_skill import gps_nav_skill +from dimos.agents.agent import Agent +from dimos.agents.skills.demo_robot import DemoRobot +from dimos.agents.skills.gps_nav_skill import GpsNavSkillContainer from dimos.core.blueprints import autoconnect demo_gps_nav = autoconnect( - demo_robot(), - gps_nav_skill(), - agent(), + DemoRobot.blueprint(), + GpsNavSkillContainer.blueprint(), + Agent.blueprint(), ) diff --git a/dimos/agents/skills/demo_robot.py b/dimos/agents/skills/demo_robot.py index 789e26d7e1..2917ec2d76 100644 --- a/dimos/agents/skills/demo_robot.py +++ b/dimos/agents/skills/demo_robot.py @@ -32,9 +32,3 @@ def stop(self) -> None: def _publish_gps_location(self) -> None: self.gps_location.publish(LatLon(lat=37.78092426217621, lon=-122.40682866540769)) - - -demo_robot = DemoRobot.blueprint - - -__all__ = ["DemoRobot", "demo_robot"] diff --git a/dimos/agents/skills/demo_skill.py b/dimos/agents/skills/demo_skill.py index b067a3fbc2..81935d25b8 100644 --- a/dimos/agents/skills/demo_skill.py +++ b/dimos/agents/skills/demo_skill.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.agent import agent -from dimos.agents.skills.demo_calculator_skill import demo_calculator_skill +from dimos.agents.agent import Agent +from dimos.agents.skills.demo_calculator_skill import DemoCalculatorSkill from dimos.core.blueprints import autoconnect demo_skill = autoconnect( - demo_calculator_skill(), - agent(), + DemoCalculatorSkill.blueprint(), + Agent.blueprint(), ) diff --git a/dimos/agents/skills/google_maps_skill_container.py b/dimos/agents/skills/google_maps_skill_container.py index e218601696..ee48e51653 100644 --- a/dimos/agents/skills/google_maps_skill_container.py +++ b/dimos/agents/skills/google_maps_skill_container.py @@ -124,8 +124,3 @@ def get_gps_position_for_queries(self, queries: list[str]) -> str: results.append(f"no result for {query}") return json.dumps(results) - - -google_maps_skill = GoogleMapsSkillContainer.blueprint - -__all__ = ["GoogleMapsSkillContainer", "google_maps_skill"] diff --git a/dimos/agents/skills/gps_nav_skill.py b/dimos/agents/skills/gps_nav_skill.py index 1464665131..c6f86951be 100644 --- a/dimos/agents/skills/gps_nav_skill.py +++ b/dimos/agents/skills/gps_nav_skill.py @@ -98,9 +98,3 @@ def _convert_point(self, point: dict[str, float]) -> LatLon | None: return None return LatLon(lat=lat, lon=lon) - - -gps_nav_skill = GpsNavSkillContainer.blueprint - - -__all__ = ["GpsNavSkillContainer", "gps_nav_skill"] diff --git a/dimos/agents/skills/navigation.py b/dimos/agents/skills/navigation.py index 47ae21c799..e366465959 100644 --- a/dimos/agents/skills/navigation.py +++ b/dimos/agents/skills/navigation.py @@ -321,8 +321,3 @@ def _get_goal_pose_from_result(self, result: dict[str, Any]) -> PoseStamped | No orientation=Quaternion.from_euler(make_vector3(0, 0, theta)), frame_id="map", ) - - -navigation_skill = NavigationSkillContainer.blueprint - -__all__ = ["NavigationSkillContainer", "navigation_skill"] diff --git a/dimos/agents/skills/osm.py b/dimos/agents/skills/osm.py index d0281fb808..a89e86044f 100644 --- a/dimos/agents/skills/osm.py +++ b/dimos/agents/skills/osm.py @@ -78,8 +78,3 @@ def map_query(self, query_sentence: str) -> str: distance = int(distance_in_meters(latlon, self._latest_location)) # type: ignore[arg-type] return f"{context}. It's at position latitude={latlon.lat}, longitude={latlon.lon}. It is {distance} meters away." - - -osm_skill = OsmSkill.blueprint - -__all__ = ["OsmSkill", "osm_skill"] diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index 563fcd4f59..56d3de62d3 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -311,8 +311,3 @@ def _send_stop_reason(self, query: str, reason: str) -> None: def _decode_base64_image(b64: str) -> Image: bgr_array = TurboJPEG().decode(base64.b64decode(b64)) return Image(data=bgr_array, format=ImageFormat.BGR) - - -person_follow_skill = PersonFollowSkillContainer.blueprint - -__all__ = ["PersonFollowSkillContainer", "person_follow_skill"] diff --git a/dimos/agents/skills/speak_skill.py b/dimos/agents/skills/speak_skill.py index aa06d30ba4..802aec03d0 100644 --- a/dimos/agents/skills/speak_skill.py +++ b/dimos/agents/skills/speak_skill.py @@ -97,8 +97,3 @@ def set_as_complete_e(_e: Exception) -> None: subscription.dispose() return f"Spoke: {text}" - - -speak_skill = SpeakSkill.blueprint - -__all__ = ["SpeakSkill", "speak_skill"] diff --git a/dimos/agents/vlm_agent.py b/dimos/agents/vlm_agent.py index 81bad79ae5..114302b397 100644 --- a/dimos/agents/vlm_agent.py +++ b/dimos/agents/vlm_agent.py @@ -121,8 +121,3 @@ def query_image( response = self._invoke_image(image, query, response_format=response_format) content = response.content return content if isinstance(content, str) else str(content) - - -vlm_agent = VLMAgent.blueprint - -__all__ = ["VLMAgent", "vlm_agent"] diff --git a/dimos/agents/vlm_stream_tester.py b/dimos/agents/vlm_stream_tester.py index 5f2165dc8d..80353dbfe0 100644 --- a/dimos/agents/vlm_stream_tester.py +++ b/dimos/agents/vlm_stream_tester.py @@ -173,8 +173,3 @@ def _run_rpc_queries(self) -> None: except Exception as exc: logger.warning("RPC query_image failed", error=str(exc)) time.sleep(self._query_interval_s) - - -vlm_stream_tester = VlmStreamTester.blueprint - -__all__ = ["VlmStreamTester", "vlm_stream_tester"] diff --git a/dimos/agents/web_human_input.py b/dimos/agents/web_human_input.py index 22fdb231b3..2b84736d27 100644 --- a/dimos/agents/web_human_input.py +++ b/dimos/agents/web_human_input.py @@ -84,8 +84,3 @@ def stop(self) -> None: if self._human_transport: self._human_transport.lcm.stop() super().stop() - - -web_input = WebInput.blueprint - -__all__ = ["WebInput", "web_input"] diff --git a/dimos/control/README.md b/dimos/control/README.md index 755bfbd939..15303b421d 100644 --- a/dimos/control/README.md +++ b/dimos/control/README.md @@ -96,9 +96,9 @@ dimos/control/ ## Configuration ```python -from dimos.control import control_coordinator, HardwareComponent, TaskConfig +from dimos.control import ControlCoordinator, HardwareComponent, TaskConfig -my_robot = control_coordinator( +my_robot = ControlCoordinator.blueprint( tick_rate=100.0, hardware=[ HardwareComponent( diff --git a/dimos/control/blueprints/basic.py b/dimos/control/blueprints/basic.py index 7ad441ed70..58619f6fc3 100644 --- a/dimos/control/blueprints/basic.py +++ b/dimos/control/blueprints/basic.py @@ -24,12 +24,12 @@ from __future__ import annotations from dimos.control.blueprints._hardware import mock_arm, piper, xarm6, xarm7 -from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.control.coordinator import ControlCoordinator, TaskConfig from dimos.core.transport import LCMTransport from dimos.msgs.sensor_msgs.JointState import JointState # Minimal blueprint (no hardware, no tasks) -coordinator_basic = control_coordinator( +coordinator_basic = ControlCoordinator.blueprint( tick_rate=100.0, publish_joint_state=True, joint_state_frame_id="coordinator", @@ -40,7 +40,7 @@ ) # Mock 7-DOF arm (for testing) -coordinator_mock = control_coordinator( +coordinator_mock = ControlCoordinator.blueprint( hardware=[mock_arm()], tasks=[ TaskConfig( @@ -57,7 +57,7 @@ ) # XArm7 real hardware -coordinator_xarm7 = control_coordinator( +coordinator_xarm7 = ControlCoordinator.blueprint( hardware=[xarm7()], tasks=[ TaskConfig( @@ -74,7 +74,7 @@ ) # XArm6 real hardware -coordinator_xarm6 = control_coordinator( +coordinator_xarm6 = ControlCoordinator.blueprint( hardware=[xarm6()], tasks=[ TaskConfig( @@ -91,7 +91,7 @@ ) # Piper arm (6-DOF, CAN bus) -coordinator_piper = control_coordinator( +coordinator_piper = ControlCoordinator.blueprint( hardware=[piper()], tasks=[ TaskConfig( diff --git a/dimos/control/blueprints/dual.py b/dimos/control/blueprints/dual.py index 8482316ba5..057e982f90 100644 --- a/dimos/control/blueprints/dual.py +++ b/dimos/control/blueprints/dual.py @@ -23,12 +23,12 @@ from __future__ import annotations from dimos.control.blueprints._hardware import mock_arm, piper, xarm6, xarm7 -from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.control.coordinator import ControlCoordinator, TaskConfig from dimos.core.transport import LCMTransport from dimos.msgs.sensor_msgs.JointState import JointState # Dual mock arms (7-DOF left, 6-DOF right) -coordinator_dual_mock = control_coordinator( +coordinator_dual_mock = ControlCoordinator.blueprint( hardware=[mock_arm("left_arm", 7), mock_arm("right_arm", 6)], tasks=[ TaskConfig( @@ -51,7 +51,7 @@ ) # Dual XArm (XArm7 left, XArm6 right) -coordinator_dual_xarm = control_coordinator( +coordinator_dual_xarm = ControlCoordinator.blueprint( hardware=[xarm7("left_arm"), xarm6("right_arm")], tasks=[ TaskConfig( @@ -74,7 +74,7 @@ ) # Dual arm (XArm6 + Piper) -coordinator_piper_xarm = control_coordinator( +coordinator_piper_xarm = ControlCoordinator.blueprint( hardware=[xarm6("xarm_arm"), piper("piper_arm")], tasks=[ TaskConfig( diff --git a/dimos/control/blueprints/mobile.py b/dimos/control/blueprints/mobile.py index 4ed3410b8f..5e5e1966b8 100644 --- a/dimos/control/blueprints/mobile.py +++ b/dimos/control/blueprints/mobile.py @@ -23,7 +23,7 @@ from dimos.control.blueprints._hardware import mock_arm, mock_twist_base from dimos.control.components import make_twist_base_joints -from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.control.coordinator import ControlCoordinator, TaskConfig from dimos.core.transport import LCMTransport from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.sensor_msgs.JointState import JointState @@ -31,7 +31,7 @@ _base_joints = make_twist_base_joints("base") # Mock holonomic twist base (3-DOF: vx, vy, wz) -coordinator_mock_twist_base = control_coordinator( +coordinator_mock_twist_base = ControlCoordinator.blueprint( hardware=[mock_twist_base()], tasks=[ TaskConfig( @@ -49,7 +49,7 @@ ) # Mock arm (7-DOF) + mock holonomic base (3-DOF) -coordinator_mobile_manip_mock = control_coordinator( +coordinator_mobile_manip_mock = ControlCoordinator.blueprint( hardware=[mock_arm(), mock_twist_base()], tasks=[ TaskConfig( diff --git a/dimos/control/blueprints/teleop.py b/dimos/control/blueprints/teleop.py index 2e922bbcbf..1dfa55d80d 100644 --- a/dimos/control/blueprints/teleop.py +++ b/dimos/control/blueprints/teleop.py @@ -37,14 +37,14 @@ xarm7, ) from dimos.control.components import make_gripper_joints -from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.control.coordinator import ControlCoordinator, TaskConfig from dimos.core.transport import LCMTransport from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.quest.quest_types import Buttons # XArm6 teleop - streaming position control -coordinator_teleop_xarm6 = control_coordinator( +coordinator_teleop_xarm6 = ControlCoordinator.blueprint( hardware=[xarm6()], tasks=[ TaskConfig( @@ -62,7 +62,7 @@ ) # XArm6 velocity control - streaming velocity for joystick -coordinator_velocity_xarm6 = control_coordinator( +coordinator_velocity_xarm6 = ControlCoordinator.blueprint( hardware=[xarm6()], tasks=[ TaskConfig( @@ -80,7 +80,7 @@ ) # XArm6 combined (servo + velocity tasks) -coordinator_combined_xarm6 = control_coordinator( +coordinator_combined_xarm6 = ControlCoordinator.blueprint( hardware=[xarm6()], tasks=[ TaskConfig( @@ -105,7 +105,7 @@ # Mock 6-DOF arm with CartesianIK -coordinator_cartesian_ik_mock = control_coordinator( +coordinator_cartesian_ik_mock = ControlCoordinator.blueprint( hardware=[mock_arm("arm", 6)], tasks=[ TaskConfig( @@ -127,7 +127,7 @@ ) # Piper arm with CartesianIK -coordinator_cartesian_ik_piper = control_coordinator( +coordinator_cartesian_ik_piper = ControlCoordinator.blueprint( hardware=[piper()], tasks=[ TaskConfig( @@ -150,7 +150,7 @@ # Single XArm7 with TeleopIK -coordinator_teleop_xarm7 = control_coordinator( +coordinator_teleop_xarm7 = ControlCoordinator.blueprint( hardware=[xarm7(gripper=True)], tasks=[ TaskConfig( @@ -177,7 +177,7 @@ ) # Single Piper with TeleopIK -coordinator_teleop_piper = control_coordinator( +coordinator_teleop_piper = ControlCoordinator.blueprint( hardware=[piper()], tasks=[ TaskConfig( @@ -201,7 +201,7 @@ ) # Dual arm teleop: XArm6 + Piper with TeleopIK -coordinator_teleop_dual = control_coordinator( +coordinator_teleop_dual = ControlCoordinator.blueprint( hardware=[xarm6("xarm_arm"), piper("piper_arm")], tasks=[ TaskConfig( diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index 0757f27705..9f3264f85c 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -722,16 +722,3 @@ def stop(self) -> None: def get_tick_count(self) -> int: """Get the number of ticks since start.""" return self._tick_loop.tick_count if self._tick_loop else 0 - - -# Blueprint export -control_coordinator = ControlCoordinator.blueprint - - -__all__ = [ - "ControlCoordinator", - "ControlCoordinatorConfig", - "HardwareComponent", - "TaskConfig", - "control_coordinator", -] diff --git a/dimos/control/examples/twist_base_keyboard_teleop.py b/dimos/control/examples/twist_base_keyboard_teleop.py index 610f8679e4..44cd34c354 100644 --- a/dimos/control/examples/twist_base_keyboard_teleop.py +++ b/dimos/control/examples/twist_base_keyboard_teleop.py @@ -34,13 +34,13 @@ from __future__ import annotations from dimos.control.blueprints.mobile import coordinator_mock_twist_base -from dimos.robot.unitree.keyboard_teleop import keyboard_teleop +from dimos.robot.unitree.keyboard_teleop import KeyboardTeleop def main() -> None: """Run mock twist base + keyboard teleop.""" coord = coordinator_mock_twist_base.build() - teleop = keyboard_teleop().build() + teleop = KeyboardTeleop.blueprint().build() print("Starting mock twist base coordinator + keyboard teleop...") print("Coordinator tick loop: 100Hz") diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index 5f7bf33b8b..b61e34c5f9 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -107,11 +107,6 @@ class ModuleC(Module): data3: In[Data3] -module_a = ModuleA.blueprint -module_b = ModuleB.blueprint -module_c = ModuleC.blueprint - - def test_get_connection_set() -> None: assert _BlueprintAtom.create(CatModule, kwargs={"k": "v"}) == _BlueprintAtom( module=CatModule, @@ -125,7 +120,7 @@ def test_get_connection_set() -> None: def test_autoconnect() -> None: - blueprint_set = autoconnect(module_a(), module_b()) + blueprint_set = autoconnect(ModuleA.blueprint(), ModuleB.blueprint()) assert blueprint_set == Blueprint( blueprints=( @@ -154,7 +149,7 @@ def test_autoconnect() -> None: def test_transports() -> None: custom_transport = LCMTransport("/custom_topic", Data1) - blueprint_set = autoconnect(module_a(), module_b()).transports( + blueprint_set = autoconnect(ModuleA.blueprint(), ModuleB.blueprint()).transports( {("data1", Data1): custom_transport} ) @@ -163,7 +158,9 @@ def test_transports() -> None: def test_global_config() -> None: - blueprint_set = autoconnect(module_a(), module_b()).global_config(option1=True, option2=42) + blueprint_set = autoconnect(ModuleA.blueprint(), ModuleB.blueprint()).global_config( + option1=True, option2=42 + ) assert "option1" in blueprint_set.global_config_overrides assert blueprint_set.global_config_overrides["option1"] is True @@ -173,7 +170,7 @@ def test_global_config() -> None: @pytest.mark.slow def test_build_happy_path() -> None: - blueprint_set = autoconnect(module_a(), module_b(), module_c()) + blueprint_set = autoconnect(ModuleA.blueprint(), ModuleB.blueprint(), ModuleC.blueprint()) coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) @@ -475,7 +472,9 @@ def test_module_ref_spec() -> None: @pytest.mark.slow def test_disabled_modules_are_skipped_during_build() -> None: - blueprint_set = autoconnect(module_a(), module_b(), module_c()).disabled_modules(ModuleC) + blueprint_set = autoconnect( + ModuleA.blueprint(), ModuleB.blueprint(), ModuleC.blueprint() + ).disabled_modules(ModuleC) coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) @@ -490,11 +489,11 @@ def test_disabled_modules_are_skipped_during_build() -> None: def test_autoconnect_merges_disabled_modules() -> None: bp_a = Blueprint( - blueprints=module_a().blueprints, + blueprints=ModuleA.blueprint().blueprints, disabled_modules_tuple=(ModuleA,), ) bp_b = Blueprint( - blueprints=module_b().blueprints, + blueprints=ModuleB.blueprint().blueprints, disabled_modules_tuple=(ModuleB,), ) diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index e0d0b3407e..b8165658d9 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -32,7 +32,7 @@ from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.spec import perception -from dimos.visualization.rerun.bridge import rerun_bridge +from dimos.visualization.rerun.bridge import RerunBridgeModule def default_transform() -> Transform: @@ -118,11 +118,7 @@ def stop(self) -> None: super().stop() -camera_module = CameraModule.blueprint - demo_camera = autoconnect( - camera_module(), - rerun_bridge(), + CameraModule.blueprint(), + RerunBridgeModule.blueprint(), ) - -__all__ = ["CameraModule", "camera_module"] diff --git a/dimos/hardware/sensors/camera/realsense/camera.py b/dimos/hardware/sensors/camera/realsense/camera.py index 23bc19cdad..821982981d 100644 --- a/dimos/hardware/sensors/camera/realsense/camera.py +++ b/dimos/hardware/sensors/camera/realsense/camera.py @@ -479,8 +479,3 @@ def cleanup() -> None: if __name__ == "__main__": main() - - -realsense_camera = RealSenseCamera.blueprint - -__all__ = ["RealSenseCamera", "RealSenseCameraConfig", "realsense_camera"] diff --git a/dimos/hardware/sensors/camera/zed/camera.py b/dimos/hardware/sensors/camera/zed/camera.py index 214b1f73e3..dd429c29cf 100644 --- a/dimos/hardware/sensors/camera/zed/camera.py +++ b/dimos/hardware/sensors/camera/zed/camera.py @@ -528,6 +528,3 @@ def cleanup() -> None: ZEDModule = ZEDCamera -zed_camera = ZEDCamera.blueprint - -__all__ = ["ZEDCamera", "ZEDCameraConfig", "ZEDModule", "zed_camera"] diff --git a/dimos/hardware/sensors/camera/zed/compat.py b/dimos/hardware/sensors/camera/zed/compat.py index 3cec8d9566..c00971e471 100644 --- a/dimos/hardware/sensors/camera/zed/compat.py +++ b/dimos/hardware/sensors/camera/zed/compat.py @@ -28,26 +28,24 @@ HAS_ZED_SDK = False if HAS_ZED_SDK: - from dimos.hardware.sensors.camera.zed.camera import ZEDCamera, ZEDModule, zed_camera + from dimos.hardware.sensors.camera.zed.camera import ZEDCamera, ZEDModule else: # Provide stub classes when SDK is not available + _ZED_ERR = ( + "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." + ) + class ZEDCamera: # type: ignore[no-redef] def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - raise ImportError( - "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." - ) + raise ImportError(_ZED_ERR) + + @classmethod + def blueprint(cls, *args: object, **kwargs: object) -> None: # type: ignore[no-untyped-def] + raise ImportError(_ZED_ERR) class ZEDModule: # type: ignore[no-redef] def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - raise ImportError( - "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." - ) - - def zed_camera(*args: object, **kwargs: object) -> None: # type: ignore[misc,no-redef] - raise ModuleNotFoundError( - "ZED SDK not installed. Please install pyzed package to use ZED camera functionality.", - name="pyzed", - ) + raise ImportError(_ZED_ERR) # Set up camera calibration provider (always available) @@ -59,5 +57,4 @@ def zed_camera(*args: object, **kwargs: object) -> None: # type: ignore[misc,no "CameraInfo", "ZEDCamera", "ZEDModule", - "zed_camera", ] diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index b1a6baef44..f3de842b46 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -15,13 +15,13 @@ from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 from dimos.mapping.voxels import VoxelGridMapper -from dimos.visualization.rerun.bridge import rerun_bridge +from dimos.visualization.rerun.bridge import RerunBridgeModule voxel_size = 0.05 mid360_fastlio = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), - rerun_bridge( + RerunBridgeModule.blueprint( visual_override={ "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), } @@ -31,7 +31,7 @@ mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), VoxelGridMapper.blueprint(publish_interval=1.0, voxel_size=voxel_size, carve_columns=False), - rerun_bridge( + RerunBridgeModule.blueprint( visual_override={ "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), "world/lidar": None, @@ -41,7 +41,7 @@ mid360_fastlio_voxels_native = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=3.0), - rerun_bridge( + RerunBridgeModule.blueprint( visual_override={ "world/lidar": None, "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), diff --git a/dimos/hardware/sensors/lidar/fastlio2/module.py b/dimos/hardware/sensors/lidar/fastlio2/module.py index c1a96a525b..cdce59bd81 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/module.py +++ b/dimos/hardware/sensors/lidar/fastlio2/module.py @@ -132,14 +132,6 @@ class FastLio2( global_map: Out[PointCloud2] -fastlio2_module = FastLio2.blueprint - -__all__ = [ - "FastLio2", - "FastLio2Config", - "fastlio2_module", -] - # Verify protocol port compliance (mypy will flag missing ports) if TYPE_CHECKING: FastLio2() diff --git a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py index 9ded4578ba..c8835b3e89 100644 --- a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py +++ b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py @@ -14,9 +14,9 @@ from dimos.core.blueprints import autoconnect from dimos.hardware.sensors.lidar.livox.module import Mid360 -from dimos.visualization.rerun.bridge import rerun_bridge +from dimos.visualization.rerun.bridge import RerunBridgeModule mid360 = autoconnect( Mid360.blueprint(), - rerun_bridge(), + RerunBridgeModule.blueprint(), ).global_config(n_workers=2, robot_model="mid360") diff --git a/dimos/hardware/sensors/lidar/livox/module.py b/dimos/hardware/sensors/lidar/livox/module.py index 999cdd9aa1..5701a4b4d4 100644 --- a/dimos/hardware/sensors/lidar/livox/module.py +++ b/dimos/hardware/sensors/lidar/livox/module.py @@ -89,14 +89,6 @@ class Mid360(NativeModule[Mid360Config], perception.Lidar, perception.IMU): imu: Out[Imu] -mid360_module = Mid360.blueprint - -__all__ = [ - "Mid360", - "Mid360Config", - "mid360_module", -] - # Verify protocol port compliance (mypy will flag missing ports) if TYPE_CHECKING: Mid360() diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index 8ef2c03279..8110166042 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -32,20 +32,20 @@ from dimos.agents.agent import Agent from dimos.control.components import HardwareComponent, HardwareType, make_joints -from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.control.coordinator import ControlCoordinator, TaskConfig from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport -from dimos.hardware.sensors.camera.realsense.camera import realsense_camera -from dimos.manipulation.manipulation_module import manipulation_module -from dimos.manipulation.pick_and_place_module import pick_and_place_module +from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera +from dimos.manipulation.manipulation_module import ManipulationModule +from dimos.manipulation.pick_and_place_module import PickAndPlaceModule from dimos.manipulation.planning.spec.config import RobotModelConfig from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.perception.object_scene_registration import object_scene_registration_module -from dimos.robot.foxglove_bridge import foxglove_bridge # TODO: migrate to rerun +from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule +from dimos.robot.foxglove_bridge import FoxgloveBridge # TODO: migrate to rerun from dimos.utils.data import get_data @@ -273,7 +273,7 @@ def _make_piper_config( # Single XArm6 planner (standalone, no coordinator) -xarm6_planner_only = manipulation_module( +xarm6_planner_only = ManipulationModule.blueprint( robots=[_make_xarm6_config()], planning_timeout=10.0, enable_viz=True, @@ -286,7 +286,7 @@ def _make_piper_config( # Dual XArm6 planner with coordinator integration # Usage: Start with coordinator_dual_mock, then plan/execute via RPC -dual_xarm6_planner = manipulation_module( +dual_xarm6_planner = ManipulationModule.blueprint( robots=[ _make_xarm6_config( "left_arm", y_offset=0.5, joint_prefix="left_", coordinator_task="traj_left" @@ -307,12 +307,12 @@ def _make_piper_config( # Single XArm7 planner + mock coordinator (standalone, no external coordinator needed) # Usage: dimos run xarm7-planner-coordinator xarm7_planner_coordinator = autoconnect( - manipulation_module( + ManipulationModule.blueprint( robots=[_make_xarm7_config("arm", joint_prefix="arm_", coordinator_task="traj_arm")], planning_timeout=10.0, enable_viz=True, ), - control_coordinator( + ControlCoordinator.blueprint( tick_rate=100.0, publish_joint_state=True, joint_state_frame_id="coordinator", @@ -387,7 +387,7 @@ def _make_piper_config( xarm_perception = ( autoconnect( - pick_and_place_module( + PickAndPlaceModule.blueprint( robots=[ _make_xarm7_config( "arm", @@ -402,12 +402,12 @@ def _make_piper_config( planning_timeout=10.0, enable_viz=True, ), - realsense_camera( + RealSenseCamera.blueprint( base_frame_id="link7", base_transform=_XARM_PERCEPTION_CAMERA_TRANSFORM, ), - object_scene_registration_module(target_frame="world"), - foxglove_bridge(), # TODO: migrate to rerun + ObjectSceneRegistrationModule.blueprint(target_frame="world"), + FoxgloveBridge.blueprint(), # TODO: migrate to rerun ) .transports( { diff --git a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py index 0cbd41e218..6b702495e2 100644 --- a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py +++ b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py @@ -708,7 +708,3 @@ def _integrate_velocity(self, current_pose: Pose, velocity: Twist, dt: float) -> def _normalize_angle(angle: float) -> float: """Normalize angle to [-pi, pi].""" return math.atan2(math.sin(angle), math.cos(angle)) - - -# Expose blueprint for declarative composition -cartesian_motion_controller = CartesianMotionController.blueprint diff --git a/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py b/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py index 465df7afea..a91e1bfb11 100644 --- a/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py +++ b/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py @@ -351,7 +351,3 @@ def _execution_loop(self) -> None: time.sleep(period) logger.info("Execution loop stopped") - - -# Expose blueprint for declarative composition -joint_trajectory_controller = JointTrajectoryController.blueprint diff --git a/dimos/manipulation/grasping/demo_grasping.py b/dimos/manipulation/grasping/demo_grasping.py index 43a6c9a20a..a4eea21787 100644 --- a/dimos/manipulation/grasping/demo_grasping.py +++ b/dimos/manipulation/grasping/demo_grasping.py @@ -14,23 +14,23 @@ # limitations under the License. from pathlib import Path -from dimos.agents.agent import agent +from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect -from dimos.hardware.sensors.camera.realsense.camera import realsense_camera +from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera from dimos.manipulation.grasping.graspgen_module import graspgen -from dimos.manipulation.grasping.grasping import grasping_module +from dimos.manipulation.grasping.grasping import GraspingModule from dimos.perception.detection.detectors.yoloe import YoloePromptMode -from dimos.perception.object_scene_registration import object_scene_registration_module -from dimos.robot.foxglove_bridge import foxglove_bridge +from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule +from dimos.robot.foxglove_bridge import FoxgloveBridge -camera_module = realsense_camera(enable_pointcloud=False) +camera_module = RealSenseCamera.blueprint(enable_pointcloud=False) demo_grasping = autoconnect( camera_module, - object_scene_registration_module( + ObjectSceneRegistrationModule.blueprint( target_frame="camera_color_optical_frame", prompt_mode=YoloePromptMode.PROMPT ), - grasping_module(), + GraspingModule.blueprint(), graspgen( docker_file_path=Path(__file__).parent / "docker_context" / "Dockerfile", docker_build_context=Path(__file__).parent.parent.parent.parent, # repo root @@ -43,6 +43,6 @@ ("/tmp", "/tmp", "rw") ], # Grasp visualization debug standalone: python -m dimos.manipulation.grasping.visualize_grasps ), - foxglove_bridge(), - agent(), + FoxgloveBridge.blueprint(), + Agent.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/manipulation/grasping/graspgen_module.py b/dimos/manipulation/grasping/graspgen_module.py index c883126840..ae2d59512a 100644 --- a/dimos/manipulation/grasping/graspgen_module.py +++ b/dimos/manipulation/grasping/graspgen_module.py @@ -270,6 +270,3 @@ def graspgen( return GraspGenModule.blueprint( docker_file=dockerfile, docker_build_context=build_context, **kwargs ) - - -__all__ = ["GraspGenConfig", "GraspGenModule", "graspgen"] diff --git a/dimos/manipulation/grasping/grasping.py b/dimos/manipulation/grasping/grasping.py index ef05dc29e2..50671777c0 100644 --- a/dimos/manipulation/grasping/grasping.py +++ b/dimos/manipulation/grasping/grasping.py @@ -145,7 +145,3 @@ def _format_grasp_result(self, grasps: PoseArray, object_name: str) -> str: f"Best grasp: pos=({pos.x:.4f}, {pos.y:.4f}, {pos.z:.4f}), " f"rpy=({rpy.x:.1f}, {rpy.y:.1f}, {rpy.z:.1f}) degrees" ) - - -grasping_module = GraspingModule.blueprint -__all__ = ["GraspingModule", "grasping_module"] diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index fe5561c705..d6908d07d9 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -1121,7 +1121,3 @@ def stop(self) -> None: self._world_monitor.stop_all_monitors() super().stop() - - -# Expose blueprint for declarative composition -manipulation_module = ManipulationModule.blueprint diff --git a/dimos/manipulation/pick_and_place_module.py b/dimos/manipulation/pick_and_place_module.py index b433df6801..81e7bcf2d3 100644 --- a/dimos/manipulation/pick_and_place_module.py +++ b/dimos/manipulation/pick_and_place_module.py @@ -592,7 +592,3 @@ def stop(self) -> None: self._graspgen = None super().stop() - - -# Expose blueprint for declarative composition -pick_and_place_module = PickAndPlaceModule.blueprint diff --git a/dimos/mapping/costmapper.py b/dimos/mapping/costmapper.py index 06bf493564..87ed64d404 100644 --- a/dimos/mapping/costmapper.py +++ b/dimos/mapping/costmapper.py @@ -74,6 +74,3 @@ def stop(self) -> None: def _calculate_costmap(self, msg: PointCloud2) -> OccupancyGrid: fn = OCCUPANCY_ALGOS[self.config.algo] return fn(msg, **asdict(self.config.config)) - - -cost_mapper = CostMapper.blueprint diff --git a/dimos/mapping/osm/demo_osm.py b/dimos/mapping/osm/demo_osm.py index 97622cfaf2..54b6ab39a3 100644 --- a/dimos/mapping/osm/demo_osm.py +++ b/dimos/mapping/osm/demo_osm.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.agent import agent -from dimos.agents.skills.demo_robot import demo_robot -from dimos.agents.skills.osm import osm_skill +from dimos.agents.agent import Agent +from dimos.agents.skills.demo_robot import DemoRobot +from dimos.agents.skills.osm import OsmSkill from dimos.core.blueprints import autoconnect demo_osm = autoconnect( - demo_robot(), - osm_skill(), - agent(), + DemoRobot.blueprint(), + OsmSkill.blueprint(), + Agent.blueprint(), ) diff --git a/dimos/mapping/voxels.py b/dimos/mapping/voxels.py index e4e03dfc01..92cbeed03e 100644 --- a/dimos/mapping/voxels.py +++ b/dimos/mapping/voxels.py @@ -241,6 +241,3 @@ def ensure_legacy_pcd( ) return pcd_any.to_legacy() - - -voxel_mapper = VoxelGridMapper.blueprint diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 20fab41b35..e2f408b538 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -840,8 +840,3 @@ def end_exploration(self) -> str: return "Stopped exploration. The robot has stopped moving." else: return "Exploration skill was not active, so nothing was stopped." - - -wavefront_frontier_explorer = WavefrontFrontierExplorer.blueprint - -__all__ = ["WavefrontFrontierExplorer", "wavefront_frontier_explorer"] diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index 842a6319d4..26c540a254 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -119,8 +119,3 @@ def set_safe_goal_clearance(self, clearance: float) -> None: @rpc def reset_safe_goal_clearance(self) -> None: self._planner.reset_safe_goal_clearance() - - -replanning_a_star_planner = ReplanningAStarPlanner.blueprint - -__all__ = ["ReplanningAStarPlanner", "replanning_a_star_planner"] diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 38c8e32847..ef76539d5f 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -381,9 +381,6 @@ def stop(self) -> None: super().stop() -ros_nav = ROSNav.blueprint - - def deploy(dimos: ModuleCoordinator): # type: ignore[no-untyped-def] nav = dimos.deploy(ROSNav) # type: ignore[attr-defined] @@ -412,6 +409,3 @@ def deploy(dimos: ModuleCoordinator): # type: ignore[no-untyped-def] nav.start() return nav - - -__all__ = ["ROSNav", "deploy", "ros_nav"] diff --git a/dimos/perception/demo_object_scene_registration.py b/dimos/perception/demo_object_scene_registration.py index cdb09d359e..55b26f385a 100644 --- a/dimos/perception/demo_object_scene_registration.py +++ b/dimos/perception/demo_object_scene_registration.py @@ -13,26 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.agent import agent +from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect -from dimos.hardware.sensors.camera.realsense.camera import realsense_camera -from dimos.hardware.sensors.camera.zed.compat import zed_camera +from dimos.hardware.sensors.camera.realsense.camera import RealSenseCamera +from dimos.hardware.sensors.camera.zed.compat import ZEDCamera from dimos.perception.detection.detectors.yoloe import YoloePromptMode -from dimos.perception.object_scene_registration import object_scene_registration_module -from dimos.robot.foxglove_bridge import foxglove_bridge +from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule +from dimos.robot.foxglove_bridge import FoxgloveBridge camera_choice = "zed" if camera_choice == "realsense": - camera_module = realsense_camera(enable_pointcloud=False) + camera_module = RealSenseCamera.blueprint(enable_pointcloud=False) elif camera_choice == "zed": - camera_module = zed_camera(enable_pointcloud=False) + camera_module = ZEDCamera.blueprint(enable_pointcloud=False) else: raise ValueError(f"Invalid camera choice: {camera_choice}") demo_object_scene_registration = autoconnect( camera_module, - object_scene_registration_module(target_frame="world", prompt_mode=YoloePromptMode.LRPC), - foxglove_bridge(), - agent(), + ObjectSceneRegistrationModule.blueprint(target_frame="world", prompt_mode=YoloePromptMode.LRPC), + FoxgloveBridge.blueprint(), + Agent.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index fa392dc799..66771a4e2a 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -232,8 +232,3 @@ def deploy( # type: ignore[no-untyped-def] detector.start() return detector - - -detection3d_module = Detection3DModule.blueprint - -__all__ = ["Detection3DModule", "deploy", "detection3d_module"] diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index 5672786b94..81e73b0b04 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -311,8 +311,3 @@ def to_foxglove_scene_update(self) -> "SceneUpdate": def __len__(self) -> int: return len(self.objects.values()) - - -detection_db_module = ObjectDBModule.blueprint - -__all__ = ["ObjectDBModule", "detection_db_module"] diff --git a/dimos/perception/detection/person_tracker.py b/dimos/perception/detection/person_tracker.py index 9dbba210a2..0113135adf 100644 --- a/dimos/perception/detection/person_tracker.py +++ b/dimos/perception/detection/person_tracker.py @@ -124,8 +124,3 @@ def track(self, detections2D: ImageDetections2D) -> None: pose_in_world = tf_world_to_target.to_pose(ts=detections2D.ts) self.target.publish(pose_in_world) - - -person_tracker_module = PersonTracker.blueprint - -__all__ = ["PersonTracker", "person_tracker_module"] diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py index d4e343872b..da9fe62370 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -44,7 +44,7 @@ from .clip_filter import CLIP_AVAILABLE, adaptive_keyframes from .entity_graph_db import EntityGraphDB -from .frame_window_accumulator import Frame, FrameWindowAccumulator +from .frame_window_accumulator import FrameWindowAccumulator from .temporal_state import TemporalState from .temporal_utils.graph_utils import build_graph_context, extract_time_window from .temporal_utils.helpers import is_scene_stale @@ -624,8 +624,3 @@ def get_graph_db_stats(self) -> dict[str, Any]: if not self._graph_db: return {"stats": {}, "entities": [], "recent_relations": []} return self._graph_db.get_summary() - - -temporal_memory = TemporalMemory.blueprint - -__all__ = ["Frame", "TemporalMemory", "TemporalMemoryConfig", "temporal_memory"] diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/helpers.py b/dimos/perception/experimental/temporal_memory/temporal_utils/helpers.py index 513feb65a4..88ddee1157 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_utils/helpers.py +++ b/dimos/perception/experimental/temporal_memory/temporal_utils/helpers.py @@ -19,7 +19,7 @@ import numpy as np if TYPE_CHECKING: - from ..temporal_memory import Frame + from ..frame_window_accumulator import Frame def next_entity_id_hint(roster: Any) -> str: diff --git a/dimos/perception/object_scene_registration.py b/dimos/perception/object_scene_registration.py index 5fb1748032..3be2db4b47 100644 --- a/dimos/perception/object_scene_registration.py +++ b/dimos/perception/object_scene_registration.py @@ -354,8 +354,3 @@ def _process_3d_detections( aggregated_pc = aggregate_pointclouds(objects_for_pc) self.pointcloud.publish(aggregated_pc) return - - -object_scene_registration_module = ObjectSceneRegistrationModule.blueprint - -__all__ = ["ObjectSceneRegistrationModule", "object_scene_registration_module"] diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index 6afc5e0814..a8970c61d8 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -631,8 +631,3 @@ def _get_depth_from_bbox(self, bbox: list[int], depth_frame: np.ndarray) -> floa return depth_25th_percentile return None - - -object_tracking = ObjectTracking.blueprint - -__all__ = ["ObjectTracking", "object_tracking"] diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index fe6d7d50e0..13a3c8e289 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -583,8 +583,3 @@ def deploy( # type: ignore[no-untyped-def] spatial_memory.color_image.connect(camera.color_image) spatial_memory.start() return spatial_memory - - -spatial_memory = SpatialMemory.blueprint - -__all__ = ["SpatialMemory", "deploy", "spatial_memory"] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 5fd61891c2..6d77effe68 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -97,56 +97,78 @@ all_modules = { "agent": "dimos.agents.agent", "arm-teleop-module": "dimos.teleop.quest.quest_extensions", + "b-box-navigation-module": "dimos.navigation.bbox_navigation", + "b1-connection-module": "dimos.robot.unitree.b1.connection", "camera-module": "dimos.hardware.sensors.camera.module", "cartesian-motion-controller": "dimos.manipulation.control.servo_control.cartesian_motion_controller", "control-coordinator": "dimos.control.coordinator", "cost-mapper": "dimos.mapping.costmapper", "demo-calculator-skill": "dimos.agents.skills.demo_calculator_skill", "demo-robot": "dimos.agents.skills.demo_robot", - "detection-db-module": "dimos.perception.detection.moduleDB", - "detection3d-module": "dimos.perception.detection.module3D", - "fastlio2-module": "dimos.hardware.sensors.lidar.fastlio2.module", + "detection2-d-module": "dimos.perception.detection.module2D", + "detection3-d-module": "dimos.perception.detection.module3D", + "drone-camera-module": "dimos.robot.drone.camera_module", + "drone-connection-module": "dimos.robot.drone.connection_module", + "drone-tracking-module": "dimos.robot.drone.drone_tracking_module", + "embedding-memory": "dimos.memory.embedding", + "emitter-module": "dimos.utils.demo_image_encoding", + "fast-lio2": "dimos.hardware.sensors.lidar.fastlio2.module", "foxglove-bridge": "dimos.robot.foxglove_bridge", "g1-connection": "dimos.robot.unitree.g1.connection", + "g1-connection-base": "dimos.robot.unitree.g1.connection", "g1-sim-connection": "dimos.robot.unitree.g1.sim", - "g1-skills": "dimos.robot.unitree.g1.skill_container", "go2-connection": "dimos.robot.unitree.go2.connection", "go2-fleet-connection": "dimos.robot.unitree.go2.fleet_connection", - "google-maps-skill": "dimos.agents.skills.google_maps_skill_container", - "gps-nav-skill": "dimos.agents.skills.gps_nav_skill", + "google-maps-skill-container": "dimos.agents.skills.google_maps_skill_container", + "gps-nav-skill-container": "dimos.agents.skills.gps_nav_skill", + "grasp-gen-module": "dimos.manipulation.grasping.graspgen_module", "grasping-module": "dimos.manipulation.grasping.grasping", + "gstreamer-camera-module": "dimos.hardware.sensors.camera.gstreamer.gstreamer_camera", "joint-trajectory-controller": "dimos.manipulation.control.trajectory_controller.joint_trajectory_controller", + "joystick-module": "dimos.robot.unitree.b1.joystick_module", "keyboard-teleop": "dimos.robot.unitree.keyboard_teleop", "keyboard-teleop-module": "dimos.teleop.keyboard.keyboard_teleop_module", "manipulation-module": "dimos.manipulation.manipulation_module", - "mapper": "dimos.robot.unitree.type.map", + "map": "dimos.robot.unitree.type.map", "mcp-client": "dimos.agents.mcp.mcp_client", - "mid360-module": "dimos.hardware.sensors.lidar.livox.module", - "navigation-skill": "dimos.agents.skills.navigation", + "mcp-server": "dimos.agents.mcp.mcp_server", + "mock-b1-connection-module": "dimos.robot.unitree.b1.connection", + "module-a": "dimos.robot.unitree.demo_error_on_name_conflicts", + "module-b": "dimos.robot.unitree.demo_error_on_name_conflicts", + "navigation-module": "dimos.robot.unitree.rosnav", + "navigation-skill-container": "dimos.agents.skills.navigation", + "object-db-module": "dimos.perception.detection.moduleDB", "object-scene-registration-module": "dimos.perception.object_scene_registration", + "object-tracker2-d": "dimos.perception.object_tracker_2d", + "object-tracker3-d": "dimos.perception.object_tracker_3d", "object-tracking": "dimos.perception.object_tracker", "osm-skill": "dimos.agents.skills.osm", - "person-follow-skill": "dimos.agents.skills.person_follow", - "person-tracker-module": "dimos.perception.detection.person_tracker", + "patrolling-module": "dimos.navigation.patrolling.module", + "perceive-loop-skill": "dimos.perception.perceive_loop_skill", + "person-follow-skill-container": "dimos.agents.skills.person_follow", + "person-tracker": "dimos.perception.detection.person_tracker", "phone-teleop-module": "dimos.teleop.phone.phone_teleop_module", "pick-and-place-module": "dimos.manipulation.pick_and_place_module", "quest-teleop-module": "dimos.teleop.quest.quest_teleop_module", - "realsense-camera": "dimos.hardware.sensors.camera.realsense.camera", + "real-sense-camera": "dimos.hardware.sensors.camera.realsense.camera", + "receiver-module": "dimos.utils.demo_image_encoding", + "reid-module": "dimos.perception.detection.reid.module", "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module", - "rerun-bridge": "dimos.visualization.rerun.bridge", + "rerun-bridge-module": "dimos.visualization.rerun.bridge", "ros-nav": "dimos.navigation.rosnav", - "simple-phone-teleop-module": "dimos.teleop.phone.phone_extensions", - "simulation": "dimos.simulation.manipulators.sim_module", + "simple-phone-teleop": "dimos.teleop.phone.phone_extensions", + "simulation-module": "dimos.simulation.manipulators.sim_module", "spatial-memory": "dimos.perception.spatial_perception", "speak-skill": "dimos.agents.skills.speak_skill", "temporal-memory": "dimos.perception.experimental.temporal_memory.temporal_memory", "twist-teleop-module": "dimos.teleop.quest.quest_extensions", - "unitree-skills": "dimos.robot.unitree.unitree_skill_container", + "unitree-g1-skill-container": "dimos.robot.unitree.g1.skill_container", + "unitree-skill-container": "dimos.robot.unitree.unitree_skill_container", "vlm-agent": "dimos.agents.vlm_agent", "vlm-stream-tester": "dimos.agents.vlm_stream_tester", - "voxel-mapper": "dimos.mapping.voxels", + "voxel-grid-mapper": "dimos.mapping.voxels", "wavefront-frontier-explorer": "dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector", "web-input": "dimos.agents.web_human_input", - "websocket-vis": "dimos.web.websocket_vis.websocket_vis_module", + "websocket-vis-module": "dimos.web.websocket_vis.websocket_vis_module", "zed-camera": "dimos.hardware.sensors.camera.zed.camera", } diff --git a/dimos/robot/drone/blueprints/agentic/drone_agentic.py b/dimos/robot/drone/blueprints/agentic/drone_agentic.py index 5c0483dc24..f94af1ac5c 100644 --- a/dimos/robot/drone/blueprints/agentic/drone_agentic.py +++ b/dimos/robot/drone/blueprints/agentic/drone_agentic.py @@ -20,10 +20,10 @@ tracking, mapping skills, and an LLM agent. """ -from dimos.agents.agent import agent +from dimos.agents.agent import Agent from dimos.agents.skills.google_maps_skill_container import GoogleMapsSkillContainer from dimos.agents.skills.osm import OsmSkill -from dimos.agents.web_human_input import web_input +from dimos.agents.web_human_input import WebInput from dimos.core.blueprints import autoconnect from dimos.robot.drone.blueprints.basic.drone_basic import drone_basic from dimos.robot.drone.drone_tracking_module import DroneTrackingModule @@ -44,8 +44,8 @@ DroneTrackingModule.blueprint(outdoor=False), GoogleMapsSkillContainer.blueprint(), OsmSkill.blueprint(), - agent(system_prompt=DRONE_SYSTEM_PROMPT, model="gpt-4o"), - web_input(), + Agent.blueprint(system_prompt=DRONE_SYSTEM_PROMPT, model="gpt-4o"), + WebInput.blueprint(), ).remappings( [ (DroneTrackingModule, "video_input", "video"), diff --git a/dimos/robot/drone/blueprints/basic/drone_basic.py b/dimos/robot/drone/blueprints/basic/drone_basic.py index cfa311b80b..fbe6621ae1 100644 --- a/dimos/robot/drone/blueprints/basic/drone_basic.py +++ b/dimos/robot/drone/blueprints/basic/drone_basic.py @@ -23,7 +23,7 @@ from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.drone.camera_module import DroneCameraModule from dimos.robot.drone.connection_module import DroneConnectionModule -from dimos.web.websocket_vis.websocket_vis_module import websocket_vis +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule def _static_drone_body(rr: Any) -> list[Any]: @@ -68,13 +68,13 @@ def _drone_rerun_blueprint() -> Any: # Conditional visualization if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import foxglove_bridge + from dimos.robot.foxglove_bridge import FoxgloveBridge - _vis = foxglove_bridge() + _vis = FoxgloveBridge.blueprint() elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import _resolve_viewer_mode, rerun_bridge + from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - _vis = rerun_bridge(viewer_mode=_resolve_viewer_mode(), **_rerun_config) + _vis = RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config) else: _vis = autoconnect() @@ -92,7 +92,7 @@ def _drone_rerun_blueprint() -> Any: outdoor=False, ), DroneCameraModule.blueprint(camera_intrinsics=[1000.0, 1000.0, 960.0, 540.0]), - websocket_vis(), + WebsocketVisModule.blueprint(), ) __all__ = [ diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index 9f0fc938e5..f5b6c20c97 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -105,9 +105,3 @@ def deploy( ) foxglove_bridge.start() return foxglove_bridge - - -foxglove_bridge = FoxgloveBridge.blueprint - - -__all__ = ["FoxgloveBridge", "deploy", "foxglove_bridge"] diff --git a/dimos/robot/manipulators/piper/blueprints.py b/dimos/robot/manipulators/piper/blueprints.py index ead27fd54b..54a1242537 100644 --- a/dimos/robot/manipulators/piper/blueprints.py +++ b/dimos/robot/manipulators/piper/blueprints.py @@ -23,16 +23,16 @@ """ from dimos.control.components import HardwareComponent, HardwareType, make_joints -from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.control.coordinator import ControlCoordinator, TaskConfig from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport -from dimos.manipulation.manipulation_module import manipulation_module +from dimos.manipulation.manipulation_module import ManipulationModule from dimos.manipulation.planning.spec.config import RobotModelConfig from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.teleop.keyboard.keyboard_teleop_module import keyboard_teleop_module +from dimos.teleop.keyboard.keyboard_teleop_module import KeyboardTeleopModule from dimos.utils.data import LfsPath, get_data _PIPER_MODEL_PATH = LfsPath("piper_description/mujoco_model/piper_no_gripper_description.xml") @@ -40,8 +40,8 @@ # Piper 6-DOF mock sim + keyboard teleop + Drake visualization keyboard_teleop_piper = autoconnect( - keyboard_teleop_module(model_path=_PIPER_MODEL_PATH, ee_joint_id=6), - control_coordinator( + KeyboardTeleopModule.blueprint(model_path=_PIPER_MODEL_PATH, ee_joint_id=6), + ControlCoordinator.blueprint( tick_rate=100.0, publish_joint_state=True, joint_state_frame_id="coordinator", @@ -64,7 +64,7 @@ ), ], ), - manipulation_module( + ManipulationModule.blueprint( robots=[ RobotModelConfig( name="arm", diff --git a/dimos/robot/manipulators/xarm/blueprints.py b/dimos/robot/manipulators/xarm/blueprints.py index e699057b44..ebfe9f2329 100644 --- a/dimos/robot/manipulators/xarm/blueprints.py +++ b/dimos/robot/manipulators/xarm/blueprints.py @@ -24,17 +24,17 @@ """ from dimos.control.components import HardwareComponent, HardwareType, make_joints -from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.control.coordinator import ControlCoordinator, TaskConfig from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport from dimos.manipulation.blueprints import ( _make_xarm6_config, _make_xarm7_config, ) -from dimos.manipulation.manipulation_module import manipulation_module +from dimos.manipulation.manipulation_module import ManipulationModule from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.teleop.keyboard.keyboard_teleop_module import keyboard_teleop_module +from dimos.teleop.keyboard.keyboard_teleop_module import KeyboardTeleopModule from dimos.utils.data import LfsPath _XARM6_MODEL_PATH = LfsPath("xarm_description/urdf/xarm6/xarm6.urdf") @@ -42,8 +42,8 @@ # XArm6 mock sim + keyboard teleop + Drake visualization keyboard_teleop_xarm6 = autoconnect( - keyboard_teleop_module(model_path=_XARM6_MODEL_PATH, ee_joint_id=6), - control_coordinator( + KeyboardTeleopModule.blueprint(model_path=_XARM6_MODEL_PATH, ee_joint_id=6), + ControlCoordinator.blueprint( tick_rate=100.0, publish_joint_state=True, joint_state_frame_id="coordinator", @@ -66,7 +66,7 @@ ), ], ), - manipulation_module( + ManipulationModule.blueprint( robots=[_make_xarm6_config(name="arm", joint_prefix="arm_", add_gripper=False)], enable_viz=True, ), @@ -81,8 +81,8 @@ # XArm7 mock sim + keyboard teleop + Drake visualization keyboard_teleop_xarm7 = autoconnect( - keyboard_teleop_module(model_path=_XARM7_MODEL_PATH, ee_joint_id=7), - control_coordinator( + KeyboardTeleopModule.blueprint(model_path=_XARM7_MODEL_PATH, ee_joint_id=7), + ControlCoordinator.blueprint( tick_rate=100.0, publish_joint_state=True, joint_state_frame_id="coordinator", @@ -105,7 +105,7 @@ ), ], ), - manipulation_module( + ManipulationModule.blueprint( robots=[_make_xarm7_config(name="arm", joint_prefix="arm_", add_gripper=False)], enable_viz=True, ), diff --git a/dimos/robot/test_all_blueprints.py b/dimos/robot/test_all_blueprints.py index 6c2d000ca8..78d0540fe1 100644 --- a/dimos/robot/test_all_blueprints.py +++ b/dimos/robot/test_all_blueprints.py @@ -22,6 +22,7 @@ OPTIONAL_DEPENDENCIES = {"pyrealsense2", "pyzed", "geometry_msgs", "turbojpeg"} OPTIONAL_ERROR_SUBSTRINGS = { "Unable to locate turbojpeg library automatically", + "ZED SDK not installed", } diff --git a/dimos/robot/test_all_blueprints_generation.py b/dimos/robot/test_all_blueprints_generation.py index e110df74e2..c4b9652e47 100644 --- a/dimos/robot/test_all_blueprints_generation.py +++ b/dimos/robot/test_all_blueprints_generation.py @@ -17,6 +17,7 @@ import difflib import os from pathlib import Path +import re import subprocess import pytest @@ -32,6 +33,7 @@ "dimos/core/test_blueprints.py", } BLUEPRINT_METHODS = {"transports", "global_config", "remappings", "requirements", "configurators"} +_EXCLUDED_MODULE_NAMES = {"Module", "ModuleBase"} def test_all_blueprints_is_current() -> None: @@ -76,22 +78,107 @@ def test_all_blueprints_is_current() -> None: ) +def _camel_to_snake(name: str) -> str: + """Convert CamelCase class name to snake_case.""" + s = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", name) + s = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s) + return s.lower() + + +def _get_base_class_names(node: ast.ClassDef) -> list[str]: + """Extract base class names from a ClassDef, handling Name, Attribute, and Subscript.""" + names: list[str] = [] + for base in node.bases: + if isinstance(base, ast.Name): + names.append(base.id) + elif isinstance(base, ast.Attribute): + names.append(base.attr) + elif isinstance(base, ast.Subscript): + # Handle Generic[T] style: class Module(ModuleBase[ConfigT]) + v = base.value + if isinstance(v, ast.Name): + names.append(v.id) + elif isinstance(v, ast.Attribute): + names.append(v.attr) + return names + + +def _build_module_class_set(root: Path) -> set[str]: + """Build the set of all class names that are Module subclasses. + + Uses the same transitive-closure approach as dimos.core.test_modules: + start from {"Module", "ModuleBase"} and iteratively add any class whose + base appears in the known set until convergence. + """ + known: set[str] = {"Module", "ModuleBase"} + all_classes: list[tuple[str, list[str]]] = [] + + for path in sorted(root.rglob("*.py")): + if "__pycache__" in str(path): + continue + try: + tree = ast.parse(path.read_text("utf-8"), str(path)) + except Exception: + continue + for node in tree.body: + if isinstance(node, ast.ClassDef): + all_classes.append((node.name, _get_base_class_names(node))) + + changed = True + while changed: + changed = False + for name, bases in all_classes: + if name not in known and any(b in known for b in bases): + known.add(name) + changed = True + + return known + + +def _is_production_module_file(file_path: Path, root: Path) -> bool: + """Return True if this file should contribute to the all_modules registry. + + Excludes test helpers, deprecated code, and framework base classes in core/. + """ + rel = str(file_path.relative_to(root)) + stem = file_path.stem + return not ( + stem.startswith("test_") + or "_test_" in stem + or stem.endswith("_test") + or stem.startswith("fake_") + or stem.startswith("mock_") + or "deprecated" in rel + or "/testing/" in rel + or rel.startswith("core/") + ) + + def _scan_for_blueprints(root: Path) -> tuple[dict[str, str], dict[str, str]]: all_blueprints: dict[str, str] = {} all_modules: dict[str, str] = {} + module_classes = _build_module_class_set(root) + for file_path in sorted(_get_all_python_files(root)): module_name = _path_to_module_name(file_path, root) - blueprint_vars, module_vars = _find_blueprints_in_file(file_path) + blueprint_vars, module_vars = _find_blueprints_in_file(file_path, module_classes) for var_name in blueprint_vars: full_path = f"{module_name}:{var_name}" cli_name = var_name.replace("_", "-") all_blueprints[cli_name] = full_path - for var_name in module_vars: - cli_name = var_name.replace("_", "-") - all_modules[cli_name] = module_name + # Only register modules from production files (skip test, deprecated, core) + if _is_production_module_file(file_path, root): + for var_name in module_vars: + cli_name = var_name.replace("_", "-") + all_modules[cli_name] = module_name + + # Blueprints take priority when names collide (e.g. a pre-configured + # blueprint named "mid360" vs the raw Mid360 Module class). + for key in set(all_modules) & set(all_blueprints): + del all_modules[key] return all_blueprints, all_modules @@ -161,7 +248,9 @@ def _path_to_module_name(path: Path, root: Path) -> str: return ".".join(parts) -def _find_blueprints_in_file(file_path: Path) -> tuple[list[str], list[str]]: +def _find_blueprints_in_file( + file_path: Path, module_classes: set[str] | None = None +) -> tuple[list[str], list[str]]: blueprint_vars: list[str] = [] module_vars: list[str] = [] @@ -173,24 +262,26 @@ def _find_blueprints_in_file(file_path: Path) -> tuple[list[str], list[str]]: # Only look at top-level statements (direct children of the Module node) for node in tree.body: - if not isinstance(node, ast.Assign): - continue - - # Get the variable name(s) - for target in node.targets: - if not isinstance(target, ast.Name): - continue - var_name = target.id - - if var_name.startswith("_"): + if isinstance(node, ast.Assign): + # Get the variable name(s) + for target in node.targets: + if not isinstance(target, ast.Name): + continue + var_name = target.id + + if var_name.startswith("_"): + continue + + # Check if it's a blueprint (ModuleBlueprintSet instance) + if _is_autoconnect_call(node.value) or _ends_with_blueprint_method(node.value): + blueprint_vars.append(var_name) + + # Detect Module subclasses by checking base classes against the known set + elif isinstance(node, ast.ClassDef) and module_classes: + if node.name.startswith("_") or node.name in _EXCLUDED_MODULE_NAMES: continue - - # Check if it's a blueprint (ModuleBlueprintSet instance) - if _is_autoconnect_call(node.value) or _ends_with_blueprint_method(node.value): - blueprint_vars.append(var_name) - # Check if it's a module factory (SomeModule.blueprint) - elif _is_blueprint_factory(node.value): - module_vars.append(var_name) + if any(b in module_classes for b in _get_base_class_names(node)): + module_vars.append(_camel_to_snake(node.name)) return blueprint_vars, module_vars @@ -213,9 +304,3 @@ def _ends_with_blueprint_method(node: ast.expr) -> bool: if isinstance(func, ast.Attribute) and func.attr in BLUEPRINT_METHODS: return True return False - - -def _is_blueprint_factory(node: ast.expr) -> bool: - if isinstance(node, ast.Attribute): - return node.attr == "blueprint" - return False diff --git a/dimos/robot/unitree/g1/blueprints/agentic/_agentic_skills.py b/dimos/robot/unitree/g1/blueprints/agentic/_agentic_skills.py index 820f532570..834dd6d0a3 100644 --- a/dimos/robot/unitree/g1/blueprints/agentic/_agentic_skills.py +++ b/dimos/robot/unitree/g1/blueprints/agentic/_agentic_skills.py @@ -15,18 +15,18 @@ """Agentic skills used by higher-level G1 blueprints.""" -from dimos.agents.agent import agent -from dimos.agents.skills.navigation import navigation_skill -from dimos.agents.skills.speak_skill import speak_skill +from dimos.agents.agent import Agent +from dimos.agents.skills.navigation import NavigationSkillContainer +from dimos.agents.skills.speak_skill import SpeakSkill from dimos.core.blueprints import autoconnect -from dimos.robot.unitree.g1.skill_container import g1_skills +from dimos.robot.unitree.g1.skill_container import UnitreeG1SkillContainer from dimos.robot.unitree.g1.system_prompt import G1_SYSTEM_PROMPT _agentic_skills = autoconnect( - agent(system_prompt=G1_SYSTEM_PROMPT), - navigation_skill(), - speak_skill(), - g1_skills(), + Agent.blueprint(system_prompt=G1_SYSTEM_PROMPT), + NavigationSkillContainer.blueprint(), + SpeakSkill.blueprint(), + UnitreeG1SkillContainer.blueprint(), ) __all__ = ["_agentic_skills"] diff --git a/dimos/robot/unitree/g1/blueprints/agentic/unitree_g1_full.py b/dimos/robot/unitree/g1/blueprints/agentic/unitree_g1_full.py index 7f826f2eec..b3c6dfabaa 100644 --- a/dimos/robot/unitree/g1/blueprints/agentic/unitree_g1_full.py +++ b/dimos/robot/unitree/g1/blueprints/agentic/unitree_g1_full.py @@ -18,12 +18,12 @@ from dimos.core.blueprints import autoconnect from dimos.robot.unitree.g1.blueprints.agentic._agentic_skills import _agentic_skills from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1_shm import unitree_g1_shm -from dimos.robot.unitree.keyboard_teleop import keyboard_teleop +from dimos.robot.unitree.keyboard_teleop import KeyboardTeleop unitree_g1_full = autoconnect( unitree_g1_shm, _agentic_skills, - keyboard_teleop(), + KeyboardTeleop.blueprint(), ) __all__ = ["unitree_g1_full"] diff --git a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic.py b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic.py index 1fb591e895..fd392e4aa2 100644 --- a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic.py +++ b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic.py @@ -16,16 +16,16 @@ """Basic G1 stack: base sensors plus real robot connection and ROS nav.""" from dimos.core.blueprints import autoconnect -from dimos.navigation.rosnav import ros_nav +from dimos.navigation.rosnav import ROSNav from dimos.robot.unitree.g1.blueprints.primitive.uintree_g1_primitive_no_nav import ( uintree_g1_primitive_no_nav, ) -from dimos.robot.unitree.g1.connection import g1_connection +from dimos.robot.unitree.g1.connection import G1Connection unitree_g1_basic = autoconnect( uintree_g1_primitive_no_nav, - g1_connection(), - ros_nav(), + G1Connection.blueprint(), + ROSNav.blueprint(), ) __all__ = ["unitree_g1_basic"] diff --git a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py index 603a9535ee..3294da1772 100644 --- a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py +++ b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py @@ -16,16 +16,16 @@ """Basic G1 sim stack: base sensors plus sim connection and planner.""" from dimos.core.blueprints import autoconnect -from dimos.navigation.replanning_a_star.module import replanning_a_star_planner +from dimos.navigation.replanning_a_star.module import ReplanningAStarPlanner from dimos.robot.unitree.g1.blueprints.primitive.uintree_g1_primitive_no_nav import ( uintree_g1_primitive_no_nav, ) -from dimos.robot.unitree.g1.sim import g1_sim_connection +from dimos.robot.unitree.g1.sim import G1SimConnection unitree_g1_basic_sim = autoconnect( uintree_g1_primitive_no_nav, - g1_sim_connection(), - replanning_a_star_planner(), + G1SimConnection.blueprint(), + ReplanningAStarPlanner.blueprint(), ) __all__ = ["unitree_g1_basic_sim"] diff --git a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_joystick.py b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_joystick.py index 0242556189..4dcc6a8329 100644 --- a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_joystick.py +++ b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_joystick.py @@ -17,11 +17,11 @@ from dimos.core.blueprints import autoconnect from dimos.robot.unitree.g1.blueprints.basic.unitree_g1_basic import unitree_g1_basic -from dimos.robot.unitree.keyboard_teleop import keyboard_teleop +from dimos.robot.unitree.keyboard_teleop import KeyboardTeleop unitree_g1_joystick = autoconnect( unitree_g1_basic, - keyboard_teleop(), # Pygame-based joystick control + KeyboardTeleop.blueprint(), # Pygame-based joystick control ) __all__ = ["unitree_g1_joystick"] diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/_perception_and_memory.py b/dimos/robot/unitree/g1/blueprints/perceptive/_perception_and_memory.py index 241fcb32a8..672a990f94 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/_perception_and_memory.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/_perception_and_memory.py @@ -16,12 +16,12 @@ """Perception and memory modules used by higher-level G1 blueprints.""" from dimos.core.blueprints import autoconnect -from dimos.perception.object_tracker import object_tracking -from dimos.perception.spatial_perception import spatial_memory +from dimos.perception.object_tracker import ObjectTracking +from dimos.perception.spatial_perception import SpatialMemory _perception_and_memory = autoconnect( - spatial_memory(), - object_tracking(frame_id="camera_link"), + SpatialMemory.blueprint(), + ObjectTracking.blueprint(frame_id="camera_link"), ) __all__ = ["_perception_and_memory"] diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py index 18884bd7af..9bd82f0f6f 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py @@ -28,9 +28,9 @@ from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector -from dimos.perception.detection.module3D import Detection3DModule, detection3d_module -from dimos.perception.detection.moduleDB import ObjectDBModule, detection_db_module -from dimos.perception.detection.person_tracker import PersonTracker, person_tracker_module +from dimos.perception.detection.module3D import Detection3DModule +from dimos.perception.detection.moduleDB import ObjectDBModule +from dimos.perception.detection.person_tracker import PersonTracker from dimos.robot.unitree.g1.blueprints.basic.unitree_g1_basic import unitree_g1_basic @@ -42,15 +42,15 @@ def _person_only(det: Any) -> bool: autoconnect( unitree_g1_basic, # Person detection modules with YOLO - detection3d_module( + Detection3DModule.blueprint( camera_info=zed.CameraInfo.SingleWebcam, detector=YoloPersonDetector, ), - detection_db_module( + ObjectDBModule.blueprint( camera_info=zed.CameraInfo.SingleWebcam, filter=_person_only, # Filter for person class only ), - person_tracker_module( + PersonTracker.blueprint( cameraInfo=zed.CameraInfo.SingleWebcam, ), ) diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py index be67194b62..5b127fb697 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py @@ -19,7 +19,7 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image -from dimos.robot.foxglove_bridge import foxglove_bridge +from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1 import unitree_g1 unitree_g1_shm = autoconnect( @@ -30,7 +30,7 @@ ), } ), - foxglove_bridge( + FoxgloveBridge.blueprint( shm_channels=[ "/color_image#sensor_msgs.Image", ] diff --git a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py index 242fcaf38f..c3da9521c5 100644 --- a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py +++ b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py @@ -22,11 +22,11 @@ from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config from dimos.core.transport import LCMTransport -from dimos.hardware.sensors.camera.module import camera_module # type: ignore[attr-defined] +from dimos.hardware.sensors.camera.module import CameraModule # type: ignore[attr-defined] from dimos.hardware.sensors.camera.webcam import Webcam from dimos.hardware.sensors.camera.zed import compat as zed -from dimos.mapping.costmapper import cost_mapper -from dimos.mapping.voxels import voxel_mapper +from dimos.mapping.costmapper import CostMapper +from dimos.mapping.voxels import VoxelGridMapper from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Transform import Transform @@ -38,10 +38,10 @@ from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.msgs.std_msgs.Bool import Bool from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( - wavefront_frontier_explorer, + WavefrontFrontierExplorer, ) from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from dimos.web.websocket_vis.websocket_vis_module import websocket_vis +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule def _convert_camera_info(camera_info: Any) -> Any: @@ -102,13 +102,15 @@ def _g1_rerun_blueprint() -> Any: } if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import foxglove_bridge + from dimos.robot.foxglove_bridge import FoxgloveBridge - _with_vis = autoconnect(foxglove_bridge()) + _with_vis = autoconnect(FoxgloveBridge.blueprint()) elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import _resolve_viewer_mode, rerun_bridge + from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - _with_vis = autoconnect(rerun_bridge(viewer_mode=_resolve_viewer_mode(), **rerun_config)) + _with_vis = autoconnect( + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config) + ) else: _with_vis = autoconnect() @@ -124,7 +126,7 @@ def _create_webcam() -> Webcam: _camera = ( autoconnect( - camera_module( + CameraModule.blueprint( transform=Transform( translation=Vector3(0.05, 0.0, 0.6), # height of camera on G1 robot rotation=Quaternion.from_euler(Vector3(0.0, 0.2, 0.0)), @@ -142,11 +144,11 @@ def _create_webcam() -> Webcam: autoconnect( _with_vis, _camera, - voxel_mapper(voxel_size=0.1), - cost_mapper(), - wavefront_frontier_explorer(), + VoxelGridMapper.blueprint(voxel_size=0.1), + CostMapper.blueprint(), + WavefrontFrontierExplorer.blueprint(), # Visualization - websocket_vis(), + WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_g1") .transports( diff --git a/dimos/robot/unitree/g1/connection.py b/dimos/robot/unitree/g1/connection.py index 1f3788de98..bc2ca7d3d9 100644 --- a/dimos/robot/unitree/g1/connection.py +++ b/dimos/robot/unitree/g1/connection.py @@ -112,14 +112,8 @@ def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: return self.connection.publish_request(topic, data) # type: ignore[no-any-return] -g1_connection = G1Connection.blueprint - - def deploy(dimos: ModuleCoordinator, ip: str, local_planner: LocalPlanner) -> "ModuleProxy": connection = dimos.deploy(G1Connection, ip=ip) connection.cmd_vel.connect(local_planner.cmd_vel) connection.start() return connection - - -__all__ = ["G1Connection", "G1ConnectionBase", "deploy", "g1_connection"] diff --git a/dimos/robot/unitree/g1/sim.py b/dimos/robot/unitree/g1/sim.py index 206a689284..22fc33a978 100644 --- a/dimos/robot/unitree/g1/sim.py +++ b/dimos/robot/unitree/g1/sim.py @@ -148,9 +148,3 @@ def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: logger.info(f"Publishing request to topic: {topic} with data: {data}") assert self.connection is not None return self.connection.publish_request(topic, data) - - -g1_sim_connection = G1SimConnection.blueprint - - -__all__ = ["G1SimConnection", "g1_sim_connection"] diff --git a/dimos/robot/unitree/g1/skill_container.py b/dimos/robot/unitree/g1/skill_container.py index b1342ca96d..ffe8dae5f0 100644 --- a/dimos/robot/unitree/g1/skill_container.py +++ b/dimos/robot/unitree/g1/skill_container.py @@ -158,7 +158,3 @@ def _execute_g1_command( {_mode_commands} """ - -g1_skills = UnitreeG1SkillContainer.blueprint - -__all__ = ["UnitreeG1SkillContainer", "g1_skills"] diff --git a/dimos/robot/unitree/go2/blueprints/agentic/_common_agentic.py b/dimos/robot/unitree/go2/blueprints/agentic/_common_agentic.py index 817d5e3a7d..874feddd35 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/_common_agentic.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/_common_agentic.py @@ -13,20 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.skills.navigation import navigation_skill -from dimos.agents.skills.person_follow import person_follow_skill -from dimos.agents.skills.speak_skill import speak_skill -from dimos.agents.web_human_input import web_input +from dimos.agents.skills.navigation import NavigationSkillContainer +from dimos.agents.skills.person_follow import PersonFollowSkillContainer +from dimos.agents.skills.speak_skill import SpeakSkill +from dimos.agents.web_human_input import WebInput from dimos.core.blueprints import autoconnect from dimos.robot.unitree.go2.connection import GO2Connection -from dimos.robot.unitree.unitree_skill_container import unitree_skills +from dimos.robot.unitree.unitree_skill_container import UnitreeSkillContainer _common_agentic = autoconnect( - navigation_skill(), - person_follow_skill(camera_info=GO2Connection.camera_info_static), - unitree_skills(), - web_input(), - speak_skill(), + NavigationSkillContainer.blueprint(), + PersonFollowSkillContainer.blueprint(camera_info=GO2Connection.camera_info_static), + UnitreeSkillContainer.blueprint(), + WebInput.blueprint(), + SpeakSkill.blueprint(), ) __all__ = ["_common_agentic"] diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic.py index 2fb1a4cb74..cb0d523fbd 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.agent import agent +from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect from dimos.robot.unitree.go2.blueprints.agentic._common_agentic import _common_agentic from dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial import unitree_go2_spatial unitree_go2_agentic = autoconnect( unitree_go2_spatial, - agent(), + Agent.blueprint(), _common_agentic, ) diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_huggingface.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_huggingface.py index 1c998b7495..75a2245a99 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_huggingface.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_huggingface.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.agent import agent +from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect from dimos.robot.unitree.go2.blueprints.agentic._common_agentic import _common_agentic from dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial import unitree_go2_spatial unitree_go2_agentic_huggingface = autoconnect( unitree_go2_spatial, - agent(model="huggingface:Qwen/Qwen2.5-1.5B-Instruct"), + Agent.blueprint(model="huggingface:Qwen/Qwen2.5-1.5B-Instruct"), _common_agentic, ) diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_mcp.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_mcp.py index e75b31e511..0663119adb 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_mcp.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_mcp.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.mcp.mcp_client import mcp_client +from dimos.agents.mcp.mcp_client import McpClient from dimos.agents.mcp.mcp_server import McpServer from dimos.core.blueprints import autoconnect from dimos.robot.unitree.go2.blueprints.agentic._common_agentic import _common_agentic @@ -22,7 +22,7 @@ unitree_go2_agentic_mcp = autoconnect( unitree_go2_spatial, McpServer.blueprint(), - mcp_client(), + McpClient.blueprint(), _common_agentic, ) diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_ollama.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_ollama.py index 6a518ad831..334ea52d35 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_ollama.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_agentic_ollama.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.agent import agent +from dimos.agents.agent import Agent from dimos.agents.ollama_agent import ollama_installed from dimos.core.blueprints import autoconnect from dimos.robot.unitree.go2.blueprints.agentic._common_agentic import _common_agentic @@ -21,7 +21,7 @@ unitree_go2_agentic_ollama = autoconnect( unitree_go2_spatial, - agent(model="ollama:qwen3:8b"), + Agent.blueprint(model="ollama:qwen3:8b"), _common_agentic, ).requirements( ollama_installed, diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py index 24ab47ad3b..733672bc78 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py @@ -16,8 +16,8 @@ from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config from dimos.perception.experimental.temporal_memory.temporal_memory import ( + TemporalMemory, TemporalMemoryConfig, - temporal_memory, ) from dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic import unitree_go2_agentic @@ -25,7 +25,7 @@ # AFTER global_config.update() has applied CLI flags like --new-memory. unitree_go2_temporal_memory = autoconnect( unitree_go2_agentic, - temporal_memory(config=TemporalMemoryConfig(new_memory=global_config.new_memory)), + TemporalMemory.blueprint(config=TemporalMemoryConfig(new_memory=global_config.new_memory)), ) __all__ = ["unitree_go2_temporal_memory"] diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index 3325290bf7..a0d1e6a7ae 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -24,8 +24,8 @@ from dimos.msgs.sensor_msgs.Image import Image from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator -from dimos.robot.unitree.go2.connection import go2_connection -from dimos.web.websocket_vis.websocket_vis_module import websocket_vis +from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule # Mac has some issue with high bandwidth UDP, so we use pSHMTransport for color_image # actually we can use pSHMTransport for all platforms, and for all streams @@ -108,17 +108,18 @@ def _go2_rerun_blueprint() -> Any: if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import foxglove_bridge + from dimos.robot.foxglove_bridge import FoxgloveBridge with_vis = autoconnect( _transports_base, - foxglove_bridge(shm_channels=["/color_image#sensor_msgs.Image"]), + FoxgloveBridge.blueprint(shm_channels=["/color_image#sensor_msgs.Image"]), ) elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import _resolve_viewer_mode, rerun_bridge + from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode with_vis = autoconnect( - _transports_base, rerun_bridge(viewer_mode=_resolve_viewer_mode(), **rerun_config) + _transports_base, + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), ) else: with_vis = _transports_base @@ -126,8 +127,8 @@ def _go2_rerun_blueprint() -> Any: unitree_go2_basic = ( autoconnect( with_vis, - go2_connection(), - websocket_vis(), + GO2Connection.blueprint(), + WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py index 908444b2fd..1c55f3e93c 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py @@ -23,14 +23,14 @@ from dimos.core.blueprints import autoconnect from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import with_vis -from dimos.robot.unitree.go2.fleet_connection import go2_fleet_connection -from dimos.web.websocket_vis.websocket_vis_module import websocket_vis +from dimos.robot.unitree.go2.fleet_connection import Go2FleetConnection +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule unitree_go2_fleet = ( autoconnect( with_vis, - go2_fleet_connection(), - websocket_vis(), + Go2FleetConnection.blueprint(), + WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py index 10dd290e2d..194aff60ca 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py @@ -14,21 +14,21 @@ # limitations under the License. from dimos.core.blueprints import autoconnect -from dimos.mapping.costmapper import cost_mapper -from dimos.mapping.voxels import voxel_mapper +from dimos.mapping.costmapper import CostMapper +from dimos.mapping.voxels import VoxelGridMapper from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( - wavefront_frontier_explorer, + WavefrontFrontierExplorer, ) from dimos.navigation.patrolling.module import PatrollingModule -from dimos.navigation.replanning_a_star.module import replanning_a_star_planner +from dimos.navigation.replanning_a_star.module import ReplanningAStarPlanner from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic unitree_go2 = autoconnect( unitree_go2_basic, - voxel_mapper(voxel_size=0.1), - cost_mapper(), - replanning_a_star_planner(), - wavefront_frontier_explorer(), + VoxelGridMapper.blueprint(voxel_size=0.1), + CostMapper.blueprint(), + ReplanningAStarPlanner.blueprint(), + WavefrontFrontierExplorer.blueprint(), PatrollingModule.blueprint(), ).global_config(n_workers=7, robot_model="unitree_go2") diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py index a9bb7729ae..ae76e260cf 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py @@ -23,14 +23,14 @@ from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray -from dimos.perception.detection.module3D import Detection3DModule, detection3d_module +from dimos.perception.detection.module3D import Detection3DModule from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 from dimos.robot.unitree.go2.connection import GO2Connection unitree_go2_detection = ( autoconnect( unitree_go2, - detection3d_module( + Detection3DModule.blueprint( camera_info=GO2Connection.camera_info_static, ), ) diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_spatial.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_spatial.py index 63ffab53c8..840458d998 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_spatial.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_spatial.py @@ -15,12 +15,12 @@ from dimos.core.blueprints import autoconnect from dimos.perception.perceive_loop_skill import PerceiveLoopSkill -from dimos.perception.spatial_perception import spatial_memory +from dimos.perception.spatial_perception import SpatialMemory from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 unitree_go2_spatial = autoconnect( unitree_go2, - spatial_memory(), + SpatialMemory.blueprint(), PerceiveLoopSkill.blueprint(), ).global_config(n_workers=8) diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_vlm_stream_test.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_vlm_stream_test.py index 194d3973c6..60c4c9ce43 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_vlm_stream_test.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_vlm_stream_test.py @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.vlm_agent import vlm_agent -from dimos.agents.vlm_stream_tester import vlm_stream_tester +from dimos.agents.vlm_agent import VLMAgent +from dimos.agents.vlm_stream_tester import VlmStreamTester from dimos.core.blueprints import autoconnect from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic unitree_go2_vlm_stream_test = autoconnect( unitree_go2_basic, - vlm_agent(), - vlm_stream_tester(), + VLMAgent.blueprint(), + VlmStreamTester.blueprint(), ) __all__ = ["unitree_go2_vlm_stream_test"] diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index 38da7fb439..db3ecb40fc 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -337,9 +337,6 @@ def observe(self) -> Image | None: return self._latest_video_frame -go2_connection = GO2Connection.blueprint - - def deploy(dimos: ModuleCoordinator, ip: str, prefix: str = "") -> "ModuleProxy": from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE @@ -358,6 +355,3 @@ def deploy(dimos: ModuleCoordinator, ip: str, prefix: str = "") -> "ModuleProxy" connection.start() return connection - - -__all__ = ["GO2Connection", "deploy", "go2_connection", "make_connection"] diff --git a/dimos/robot/unitree/go2/fleet_connection.py b/dimos/robot/unitree/go2/fleet_connection.py index f0e904648a..58fa854297 100644 --- a/dimos/robot/unitree/go2/fleet_connection.py +++ b/dimos/robot/unitree/go2/fleet_connection.py @@ -142,9 +142,3 @@ def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: except Exception as e: logger.error(f"Fleet publish_request failed: {e}") return self.connection.publish_request(topic, data) - - -go2_fleet_connection = Go2FleetConnection.blueprint - - -__all__ = ["Go2FleetConnection", "go2_fleet_connection"] diff --git a/dimos/robot/unitree/keyboard_teleop.py b/dimos/robot/unitree/keyboard_teleop.py index 86885bc446..a05bd53d50 100644 --- a/dimos/robot/unitree/keyboard_teleop.py +++ b/dimos/robot/unitree/keyboard_teleop.py @@ -202,8 +202,3 @@ def _update_display(self, twist: Twist) -> None: y_pos += 25 pygame.display.flip() - - -keyboard_teleop = KeyboardTeleop.blueprint - -__all__ = ["KeyboardTeleop", "keyboard_teleop"] diff --git a/dimos/robot/unitree/type/map.py b/dimos/robot/unitree/type/map.py index da45c003f7..4ec9419c53 100644 --- a/dimos/robot/unitree/type/map.py +++ b/dimos/robot/unitree/type/map.py @@ -112,9 +112,6 @@ def _publish(self, _: Any) -> None: self.global_costmap.publish(occupancygrid) -mapper = Map.blueprint - - def deploy(dimos: ModuleCoordinator, connection: Go2ConnectionProtocol): # type: ignore[no-untyped-def] mapper = dimos.deploy(Map, global_publish_interval=1.0) # type: ignore[attr-defined] mapper.global_map.transport = LCMTransport("/global_map", PointCloud2) @@ -122,6 +119,3 @@ def deploy(dimos: ModuleCoordinator, connection: Go2ConnectionProtocol): # type mapper.lidar.connect(connection.pointcloud) # type: ignore[attr-defined] mapper.start() return mapper - - -__all__ = ["Map", "mapper"] diff --git a/dimos/robot/unitree/unitree_skill_container.py b/dimos/robot/unitree/unitree_skill_container.py index a79c061567..f536d8b45c 100644 --- a/dimos/robot/unitree/unitree_skill_container.py +++ b/dimos/robot/unitree/unitree_skill_container.py @@ -334,8 +334,3 @@ def execute_sport_command(self, command_name: str) -> str: {_commands} """ - - -unitree_skills = UnitreeSkillContainer.blueprint - -__all__ = ["UnitreeSkillContainer", "unitree_skills"] diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py index 5e873ba634..66a2b5d888 100644 --- a/dimos/simulation/manipulators/sim_module.py +++ b/dimos/simulation/manipulators/sim_module.py @@ -232,12 +232,3 @@ def _resolve_joint_names(self, dof: int) -> list[str]: if len(names) >= dof: return list(names[:dof]) return [f"{self._joint_prefix}{i + 1}" for i in range(dof)] - - -simulation = SimulationModule.blueprint - -__all__ = [ - "SimulationModule", - "SimulationModuleConfig", - "simulation", -] diff --git a/dimos/simulation/sim_blueprints.py b/dimos/simulation/sim_blueprints.py index 494b97ccbf..2a8dd2d029 100644 --- a/dimos/simulation/sim_blueprints.py +++ b/dimos/simulation/sim_blueprints.py @@ -18,10 +18,10 @@ from dimos.msgs.sensor_msgs.JointState import JointState from dimos.msgs.sensor_msgs.RobotState import RobotState from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory -from dimos.simulation.manipulators.sim_module import simulation +from dimos.simulation.manipulators.sim_module import SimulationModule from dimos.utils.data import LfsPath -xarm7_trajectory_sim = simulation( +xarm7_trajectory_sim = SimulationModule.blueprint( engine="mujoco", config_path=LfsPath("xarm7/scene.xml"), headless=True, @@ -38,7 +38,6 @@ __all__ = [ - "simulation", "xarm7_trajectory_sim", ] diff --git a/dimos/teleop/keyboard/keyboard_teleop_module.py b/dimos/teleop/keyboard/keyboard_teleop_module.py index a90dc3cf44..cae1c503cd 100644 --- a/dimos/teleop/keyboard/keyboard_teleop_module.py +++ b/dimos/teleop/keyboard/keyboard_teleop_module.py @@ -213,6 +213,3 @@ def _pygame_loop(self) -> None: clock.tick(50) pygame.quit() - - -keyboard_teleop_module = KeyboardTeleopModule.blueprint diff --git a/dimos/teleop/phone/blueprints.py b/dimos/teleop/phone/blueprints.py index 86e1154d92..908944034e 100644 --- a/dimos/teleop/phone/blueprints.py +++ b/dimos/teleop/phone/blueprints.py @@ -16,22 +16,22 @@ from dimos.core.blueprints import autoconnect from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_fleet import unitree_go2_fleet -from dimos.teleop.phone.phone_extensions import simple_phone_teleop_module +from dimos.teleop.phone.phone_extensions import SimplePhoneTeleop # Simple phone teleop (mobile base axis filtering + cmd_vel output) teleop_phone = autoconnect( - simple_phone_teleop_module(), + SimplePhoneTeleop.blueprint(), ) # Phone teleop wired to Unitree Go2 teleop_phone_go2 = autoconnect( - simple_phone_teleop_module(), + SimplePhoneTeleop.blueprint(), unitree_go2_basic, ) # Phone teleop wired to Go2 fleet — twist commands sent to all robots teleop_phone_go2_fleet = autoconnect( - simple_phone_teleop_module(), + SimplePhoneTeleop.blueprint(), unitree_go2_fleet, ) diff --git a/dimos/teleop/phone/phone_extensions.py b/dimos/teleop/phone/phone_extensions.py index c5cdc1fc80..bd3e6cac7d 100644 --- a/dimos/teleop/phone/phone_extensions.py +++ b/dimos/teleop/phone/phone_extensions.py @@ -43,11 +43,3 @@ def _publish_msg(self, output_msg: TwistStamped) -> None: angular=Vector3(x=0.0, y=0.0, z=output_msg.linear.z), ) ) - - -simple_phone_teleop_module = SimplePhoneTeleop.blueprint - -__all__ = [ - "SimplePhoneTeleop", - "simple_phone_teleop_module", -] diff --git a/dimos/teleop/phone/phone_teleop_module.py b/dimos/teleop/phone/phone_teleop_module.py index 3f32063cce..35bf02fe1e 100644 --- a/dimos/teleop/phone/phone_teleop_module.py +++ b/dimos/teleop/phone/phone_teleop_module.py @@ -280,12 +280,3 @@ def _publish_msg(self, output_msg: TwistStamped) -> None: Override to customize output (e.g., apply limits, remap axes). """ self.twist_output.publish(output_msg) - - -phone_teleop_module = PhoneTeleopModule.blueprint - -__all__ = [ - "PhoneTeleopConfig", - "PhoneTeleopModule", - "phone_teleop_module", -] diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index 71e16c2da8..6855ab62ca 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -23,14 +23,14 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.teleop.quest.quest_extensions import arm_teleop_module +from dimos.teleop.quest.quest_extensions import ArmTeleopModule from dimos.teleop.quest.quest_types import Buttons -from dimos.visualization.rerun.bridge import rerun_bridge +from dimos.visualization.rerun.bridge import RerunBridgeModule # Arm teleop with press-and-hold engage (has rerun viz) teleop_quest_rerun = autoconnect( - arm_teleop_module(), - rerun_bridge(), + ArmTeleopModule.blueprint(), + RerunBridgeModule.blueprint(), ).transports( { ("left_controller_output", PoseStamped): LCMTransport("/teleop/left_delta", PoseStamped), @@ -42,7 +42,7 @@ # Single XArm7 teleop: right controller -> xarm7 teleop_quest_xarm7 = autoconnect( - arm_teleop_module(task_names={"right": "teleop_xarm"}), + ArmTeleopModule.blueprint(task_names={"right": "teleop_xarm"}), coordinator_teleop_xarm7, ).transports( { @@ -56,7 +56,7 @@ # Single Piper teleop: left controller -> piper arm teleop_quest_piper = autoconnect( - arm_teleop_module(task_names={"left": "teleop_piper"}), + ArmTeleopModule.blueprint(task_names={"left": "teleop_piper"}), coordinator_teleop_piper, ).transports( { @@ -70,7 +70,7 @@ # Dual arm teleop: right -> piper, left -> xarm6 (TeleopIK) teleop_quest_dual = autoconnect( - arm_teleop_module(task_names={"right": "teleop_piper", "left": "teleop_xarm"}), + ArmTeleopModule.blueprint(task_names={"right": "teleop_piper", "left": "teleop_xarm"}), coordinator_teleop_dual, ).transports( { diff --git a/dimos/teleop/quest/quest_extensions.py b/dimos/teleop/quest/quest_extensions.py index 674fc36f1e..eb7a453929 100644 --- a/dimos/teleop/quest/quest_extensions.py +++ b/dimos/teleop/quest/quest_extensions.py @@ -131,16 +131,3 @@ def _publish_button_state( right=right.trigger if right is not None else 0.0, ) self.buttons.publish(buttons) - - -# Module blueprints for easy instantiation -twist_teleop_module = TwistTeleopModule.blueprint -arm_teleop_module = ArmTeleopModule.blueprint - -__all__ = [ - "ArmTeleopConfig", - "ArmTeleopModule", - "TwistTeleopModule", - "arm_teleop_module", - "twist_teleop_module", -] diff --git a/dimos/teleop/quest/quest_teleop_module.py b/dimos/teleop/quest/quest_teleop_module.py index 5868aab620..28199ff084 100644 --- a/dimos/teleop/quest/quest_teleop_module.py +++ b/dimos/teleop/quest/quest_teleop_module.py @@ -379,14 +379,3 @@ def _publish_button_state( """ buttons = Buttons.from_controllers(left, right) self.buttons.publish(buttons) - - -quest_teleop_module = QuestTeleopModule.blueprint - -__all__ = [ - "Hand", - "QuestTeleopConfig", - "QuestTeleopModule", - "QuestTeleopStatus", - "quest_teleop_module", -] diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 12f998d96d..8b1cda443c 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -390,7 +390,3 @@ def cli( if __name__ == "__main__": app() - -# you don't need to include this in your blueprint if you are not creating a -# custom rerun configuration for your deployment, you can also run rerun-bridge standalone -rerun_bridge = RerunBridgeModule.blueprint diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 5514144570..685ca2b1ee 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -402,8 +402,3 @@ def _process_costmap(self, costmap: OccupancyGrid) -> dict[str, Any]: def _emit(self, event: str, data: Any) -> None: if self._broadcast_loop and not self._broadcast_loop.is_closed(): asyncio.run_coroutine_threadsafe(self.sio.emit(event, data), self._broadcast_loop) - - -websocket_vis = WebsocketVisModule.blueprint - -__all__ = ["WebsocketVisModule", "websocket_vis"] diff --git a/docs/capabilities/manipulation/adding_a_custom_arm.md b/docs/capabilities/manipulation/adding_a_custom_arm.md index 3e931a7f73..08a08b6144 100644 --- a/docs/capabilities/manipulation/adding_a_custom_arm.md +++ b/docs/capabilities/manipulation/adding_a_custom_arm.md @@ -438,13 +438,13 @@ from __future__ import annotations from pathlib import Path from dimos.control.components import HardwareComponent, HardwareType, make_joints -from dimos.control.coordinator import TaskConfig, control_coordinator +from dimos.control.coordinator import ControlCoordinator, TaskConfig from dimos.core.transport import LCMTransport from dimos.msgs.sensor_msgs import JointState # YourArm (6-DOF) — real hardware -coordinator_yourarm = control_coordinator( +coordinator_yourarm = ControlCoordinator.blueprint( tick_rate=100.0, # Control loop frequency (Hz) publish_joint_state=True, # Publish aggregated joint state joint_state_frame_id="coordinator", diff --git a/pyproject.toml b/pyproject.toml index 1535885edf..a60c3d308a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -437,7 +437,7 @@ env = [ "GOOGLE_MAPS_API_KEY=AIzafake_google_key", "PYTHONWARNINGS=ignore:cupyx.jit.rawkernel is experimental:FutureWarning", ] -addopts = "-v -r a -p no:warnings -p no:launch_testing -p no:launch_ros --import-mode=importlib --color=yes -m 'not (tool or slow or mujoco)'" +addopts = "-s -v -r a -p no:warnings -p no:launch_testing -p no:launch_ros --import-mode=importlib --color=yes -m 'not (tool or slow or mujoco)'" asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function" From 1924b613acc0db72e985eb89f64adc74a315fc40 Mon Sep 17 00:00:00 2001 From: Mustafa Bhadsorawala <39084056+mustafab0@users.noreply.github.com> Date: Sat, 21 Mar 2026 13:20:24 -0700 Subject: [PATCH 25/42] Feature: go2 webrtc TwistBase adapter for control coordinator (#1362) --- dimos/control/coordinator.py | 2 + dimos/hardware/drive_trains/registry.py | 20 ++- .../drive_trains/transport/adapter.py | 157 ++++++++++++++++++ dimos/robot/all_blueprints.py | 2 + dimos/robot/unitree/connection.py | 19 +-- .../basic/unitree_go2_coordinator.py | 76 +++++++++ .../unitree_go2_webrtc_keyboard_teleop.py | 36 ++++ dimos/simulation/mujoco/policy.py | 7 +- 8 files changed, 300 insertions(+), 19 deletions(-) create mode 100644 dimos/hardware/drive_trains/transport/adapter.py create mode 100644 dimos/robot/unitree/go2/blueprints/basic/unitree_go2_coordinator.py create mode 100644 dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index 9f3264f85c..78184d8272 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -237,6 +237,7 @@ def _create_adapter(self, component: HardwareComponent) -> ManipulatorAdapter: component.adapter_type, dof=len(component.joints), address=component.address, + hardware_id=component.hardware_id, ) def _create_twist_base_adapter(self, component: HardwareComponent) -> TwistBaseAdapter: @@ -247,6 +248,7 @@ def _create_twist_base_adapter(self, component: HardwareComponent) -> TwistBaseA component.adapter_type, dof=len(component.joints), address=component.address, + hardware_id=component.hardware_id, ) def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: diff --git a/dimos/hardware/drive_trains/registry.py b/dimos/hardware/drive_trains/registry.py index 0a513d2bd4..435d61c73f 100644 --- a/dimos/hardware/drive_trains/registry.py +++ b/dimos/hardware/drive_trains/registry.py @@ -30,9 +30,10 @@ from __future__ import annotations +from collections.abc import Callable import importlib import logging -import pkgutil +import os from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -45,10 +46,11 @@ class TwistBaseAdapterRegistry: """Registry for twist base adapters with auto-discovery.""" def __init__(self) -> None: - self._adapters: dict[str, type[TwistBaseAdapter]] = {} + self._adapters: dict[str, type[TwistBaseAdapter] | Callable[..., TwistBaseAdapter]] = {} - def register(self, name: str, cls: type[TwistBaseAdapter]) -> None: - """Register an adapter class.""" + def register( + self, name: str, cls: type[TwistBaseAdapter] | Callable[..., TwistBaseAdapter] + ) -> None: self._adapters[name.lower()] = cls def create(self, name: str, **kwargs: Any) -> TwistBaseAdapter: @@ -81,15 +83,17 @@ def discover(self) -> None: """ import dimos.hardware.drive_trains as pkg - for _, name, ispkg in pkgutil.iter_modules(pkg.__path__): - if not ispkg: + pkg_dir = pkg.__path__[0] + for entry in sorted(os.listdir(pkg_dir)): + entry_path = os.path.join(pkg_dir, entry) + if not os.path.isdir(entry_path) or entry.startswith(("_", ".")): continue try: - module = importlib.import_module(f"dimos.hardware.drive_trains.{name}.adapter") + module = importlib.import_module(f"dimos.hardware.drive_trains.{entry}.adapter") if hasattr(module, "register"): module.register(self) except ImportError as e: - logger.warning(f"Skipping twist base adapter {name}: {e}") + logger.warning(f"Skipping twist base adapter {entry}: {e}") twist_base_adapter_registry = TwistBaseAdapterRegistry() diff --git a/dimos/hardware/drive_trains/transport/adapter.py b/dimos/hardware/drive_trains/transport/adapter.py new file mode 100644 index 0000000000..5447b2eb93 --- /dev/null +++ b/dimos/hardware/drive_trains/transport/adapter.py @@ -0,0 +1,157 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transport-based twist adapter — connects coordinator to a driver module via pub/sub. + +Topics derived from hardware_id: /{hardware_id}/cmd_vel, /{hardware_id}/odom. +""" + +from __future__ import annotations + +from functools import partial +import threading +from typing import TYPE_CHECKING, Any + +from dimos.core.transport import LCMTransport +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.hardware.drive_trains.registry import TwistBaseAdapterRegistry + +logger = setup_logger() + +_ZERO_TWIST = Twist( + linear=Vector3(x=0.0, y=0.0, z=0.0), + angular=Vector3(x=0.0, y=0.0, z=0.0), +) + + +class TransportTwistAdapter: + """TwistBaseAdapter that publishes cmd_vel and subscribes to odom via pub/sub.""" + + def __init__( + self, + dof: int = 3, + hardware_id: str = "base", + transport_cls: type = LCMTransport, + **_: object, + ) -> None: + self._dof = dof + self._prefix = hardware_id + self._transport_cls = transport_cls + self._lock = threading.Lock() + self._last_velocities = [0.0] * dof + self._latest_odom: list[float] | None = None + self._cmd_vel_transport: Any = None + self._odom_transport: Any = None + self._odom_unsub: Any = None + self._connected = False + self._enabled = False + + def connect(self) -> bool: + cmd_vel_topic = f"/{self._prefix}/cmd_vel" + odom_topic = f"/{self._prefix}/odom" + + self._cmd_vel_transport = self._transport_cls(cmd_vel_topic, Twist) + self._odom_transport = self._transport_cls(odom_topic, PoseStamped) + self._odom_unsub = self._odom_transport.subscribe(self._on_odom) + + self._connected = True + logger.info(f"TransportTwistAdapter connected: cmd_vel={cmd_vel_topic}, odom={odom_topic}") + return True + + def disconnect(self) -> None: + self.write_stop() + + if self._odom_unsub is not None: + self._odom_unsub() + self._odom_unsub = None + + if self._cmd_vel_transport is not None: + self._cmd_vel_transport.stop() + self._cmd_vel_transport = None + if self._odom_transport is not None: + self._odom_transport.stop() + self._odom_transport = None + self._connected = False + self._enabled = False + with self._lock: + self._last_velocities = [0.0] * self._dof + self._latest_odom = None + + def is_connected(self) -> bool: + return self._connected + + def get_dof(self) -> int: + return self._dof + + def read_velocities(self) -> list[float]: + with self._lock: + return self._last_velocities.copy() + + def read_odometry(self) -> list[float] | None: + with self._lock: + if self._latest_odom is None: + return None + return self._latest_odom.copy() + + def write_velocities(self, velocities: list[float]) -> bool: + if len(velocities) != self._dof or self._cmd_vel_transport is None or not self._enabled: + return False + + with self._lock: + self._last_velocities = list(velocities) + + twist = Twist( + linear=Vector3(x=velocities[0], y=velocities[1] if self._dof > 1 else 0.0, z=0.0), + angular=Vector3(x=0.0, y=0.0, z=velocities[2] if self._dof > 2 else 0.0), + ) + self._cmd_vel_transport.publish(twist) + return True + + def write_stop(self) -> bool: + with self._lock: + self._last_velocities = [0.0] * self._dof + + if self._cmd_vel_transport is None: + return False + + self._cmd_vel_transport.publish(_ZERO_TWIST) + return True + + def write_enable(self, enable: bool) -> bool: + self._enabled = enable + if not enable: + self.write_stop() + return True + + def read_enabled(self) -> bool: + return self._enabled + + def _on_odom(self, msg: PoseStamped) -> None: + with self._lock: + self._latest_odom = [msg.x, msg.y, msg.yaw] + + +def register(registry: TwistBaseAdapterRegistry) -> None: + from dimos.core.transport import ROSTransport + + registry.register("transport_lcm", partial(TransportTwistAdapter, transport_cls=LCMTransport)) + registry.register("transport_ros", partial(TransportTwistAdapter, transport_cls=ROSTransport)) + + +__all__ = ["TransportTwistAdapter"] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 6d77effe68..1fe034fd29 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -79,12 +79,14 @@ "unitree-go2-agentic-mcp": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic_mcp:unitree_go2_agentic_mcp", "unitree-go2-agentic-ollama": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic_ollama:unitree_go2_agentic_ollama", "unitree-go2-basic": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic:unitree_go2_basic", + "unitree-go2-coordinator": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_coordinator:unitree_go2_coordinator", "unitree-go2-detection": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_detection:unitree_go2_detection", "unitree-go2-fleet": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_fleet:unitree_go2_fleet", "unitree-go2-ros": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_ros:unitree_go2_ros", "unitree-go2-spatial": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial:unitree_go2_spatial", "unitree-go2-temporal-memory": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_temporal_memory:unitree_go2_temporal_memory", "unitree-go2-vlm-stream-test": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_vlm_stream_test:unitree_go2_vlm_stream_test", + "unitree-go2-webrtc-keyboard-teleop": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_webrtc_keyboard_teleop:unitree_go2_webrtc_keyboard_teleop", "xarm-perception": "dimos.manipulation.blueprints:xarm_perception", "xarm-perception-agent": "dimos.manipulation.blueprints:xarm_perception_agent", "xarm6-planner-only": "dimos.manipulation.blueprints:xarm6_planner_only", diff --git a/dimos/robot/unitree/connection.py b/dimos/robot/unitree/connection.py index 7e60080f01..e410fd31f9 100644 --- a/dimos/robot/unitree/connection.py +++ b/dimos/robot/unitree/connection.py @@ -185,7 +185,7 @@ async def async_move_duration() -> None: self.stop_timer.cancel() # Auto-stop after 0.5 seconds if no new commands - self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop) + self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop_movement) self.stop_timer.daemon = True self.stop_timer.start() @@ -195,7 +195,7 @@ async def async_move_duration() -> None: future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) future.result() # Stop after duration - self.stop() + self.stop_movement() else: # Single command for continuous movement future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) @@ -280,6 +280,7 @@ def standup(self) -> bool: return bool(self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]})) def balance_stand(self) -> bool: + """Activate BalanceStand mode — enables WIRELESS_CONTROLLER joystick commands.""" return bool( self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) ) @@ -290,6 +291,10 @@ def set_obstacle_avoidance(self, enabled: bool = True) -> None: {"api_id": 1001, "parameter": {"enable": int(enabled)}}, ) + def free_walk(self) -> bool: + """Activate FreeWalk locomotion mode — enables walking and velocity commands.""" + return bool(self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["FreeWalk"]})) + def liedown(self) -> bool: return bool( self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) @@ -362,17 +367,11 @@ def get_video_stream(self, fps: int = 30) -> Observable[Image]: """ return self.video_stream() # type: ignore[no-any-return] - def stop(self) -> bool: # type: ignore[no-redef] - """Stop the robot's movement. - - Returns: - bool: True if stop command was sent successfully - """ - # Cancel timer since we're explicitly stopping + def stop_movement(self) -> None: + """Cancel the auto-stop timer (used by move() for continuous commands).""" if self.stop_timer: self.stop_timer.cancel() self.stop_timer = None - return True def disconnect(self) -> None: """Disconnect from the robot and clean up resources.""" diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_coordinator.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_coordinator.py new file mode 100644 index 0000000000..6a4ad79041 --- /dev/null +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_coordinator.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unitree Go2 ControlCoordinator — GO2Connection + coordinator via LCM transport adapter. + +Usage: + dimos run unitree-go2-coordinator + dimos --simulation run unitree-go2-coordinator +""" + +from __future__ import annotations + +from dimos.control.components import HardwareComponent, HardwareType, make_twist_base_joints +from dimos.control.coordinator import ControlCoordinator, TaskConfig +from dimos.core.blueprints import autoconnect +from dimos.core.transport import LCMTransport +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.robot.unitree.go2.connection import GO2Connection + +_go2_joints = make_twist_base_joints("go2") + +unitree_go2_coordinator = ( + autoconnect( + GO2Connection.blueprint(), + ControlCoordinator.blueprint( + hardware=[ + HardwareComponent( + hardware_id="go2", + hardware_type=HardwareType.BASE, + joints=_go2_joints, + adapter_type="transport_lcm", + ), + ], + tasks=[ + TaskConfig( + name="vel_go2", + type="velocity", + joint_names=_go2_joints, + priority=10, + ), + ], + ), + ) + .remappings( + [ + (GO2Connection, "cmd_vel", "go2_cmd_vel"), + (GO2Connection, "odom", "go2_odom"), + ] + ) + .transports( + { + ("cmd_vel", Twist): LCMTransport("/cmd_vel", Twist), + ("twist_command", Twist): LCMTransport("/cmd_vel", Twist), + ("go2_cmd_vel", Twist): LCMTransport("/go2/cmd_vel", Twist), + ("go2_odom", PoseStamped): LCMTransport("/go2/odom", PoseStamped), + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } + ) + .global_config(obstacle_avoidance=False) +) + +__all__ = ["unitree_go2_coordinator"] diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py new file mode 100644 index 0000000000..dad4054fa9 --- /dev/null +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unitree Go2 keyboard teleop via ControlCoordinator. + +Usage: + dimos run unitree-go2-webrtc-keyboard-teleop + dimos --simulation run unitree-go2-webrtc-keyboard-teleop +""" + +from __future__ import annotations + +from dimos.core.blueprints import autoconnect +from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_coordinator import ( + unitree_go2_coordinator, +) +from dimos.robot.unitree.keyboard_teleop import KeyboardTeleop + +unitree_go2_webrtc_keyboard_teleop = autoconnect( + unitree_go2_coordinator, + KeyboardTeleop.blueprint(), +) + +__all__ = ["unitree_go2_webrtc_keyboard_teleop"] diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py index 212c7ac60a..0a792baf1a 100644 --- a/dimos/simulation/mujoco/policy.py +++ b/dimos/simulation/mujoco/policy.py @@ -40,7 +40,12 @@ def __init__( drift_compensation: list[float] | None = None, ) -> None: self._output_names = ["continuous_actions"] - self._policy = ort.InferenceSession(policy_path, providers=ort.get_available_providers()) + providers = ort.get_available_providers() + try: + self._policy = ort.InferenceSession(policy_path, providers=providers) + except RuntimeError: + logger.warning("GPU providers failed, falling back to CPUExecutionProvider") + self._policy = ort.InferenceSession(policy_path, providers=["CPUExecutionProvider"]) logger.info(f"Loaded policy: {policy_path} with providers: {self._policy.get_providers()}") self._action_scale = action_scale From c24c51c37e2c4bad49155fe67e2fa62f2ed773c0 Mon Sep 17 00:00:00 2001 From: RD <63036454+ruthwikdasyam@users.noreply.github.com> Date: Sat, 21 Mar 2026 15:17:09 -0700 Subject: [PATCH 26/42] data: add sim assets for xArm6 and Piper (#1642) --- data/.lfs/piper.tar.gz | 3 +++ data/.lfs/xarm6.tar.gz | 3 +++ 2 files changed, 6 insertions(+) create mode 100644 data/.lfs/piper.tar.gz create mode 100644 data/.lfs/xarm6.tar.gz diff --git a/data/.lfs/piper.tar.gz b/data/.lfs/piper.tar.gz new file mode 100644 index 0000000000..bf1adffac7 --- /dev/null +++ b/data/.lfs/piper.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63084f05db33ad09a448bd962a999a0a266cdb4c8731c500d327ecb392a32aee +size 7471394 diff --git a/data/.lfs/xarm6.tar.gz b/data/.lfs/xarm6.tar.gz new file mode 100644 index 0000000000..7a8cf7e531 --- /dev/null +++ b/data/.lfs/xarm6.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a4a9e06b6a97e6337b59f49c3b8da79756160a4ea641134caba636ab652a525 +size 1861919 From bb16ea2cfcda8d169c36c3d71dbdb164320c69a2 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 21 Mar 2026 15:23:41 -0700 Subject: [PATCH 27/42] fix(unity-sim): use RerunBridgeModule.blueprint() after rerun_bridge rename After merging dev, the rerun_bridge function was renamed to run_bridge and changed to a standalone runner. Blueprint should use RerunBridgeModule.blueprint() like all other blueprints. --- dimos/simulation/unity/blueprint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dimos/simulation/unity/blueprint.py b/dimos/simulation/unity/blueprint.py index cceb3e697e..4dff253ca9 100644 --- a/dimos/simulation/unity/blueprint.py +++ b/dimos/simulation/unity/blueprint.py @@ -28,7 +28,7 @@ from dimos.core.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.simulation.unity.module import UnityBridgeModule -from dimos.visualization.rerun.bridge import _resolve_viewer_mode, rerun_bridge +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode def _rerun_blueprint() -> Any: @@ -57,5 +57,5 @@ def _rerun_blueprint() -> Any: unity_sim = autoconnect( UnityBridgeModule.blueprint(), - rerun_bridge(viewer_mode=_resolve_viewer_mode(), **rerun_config), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), ) From 885b7291f932dfb7c32b5fc6d0eac7eb7dbd60b3 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 21 Mar 2026 16:04:12 -0700 Subject: [PATCH 28/42] fix: update all_blueprints.py to include unity-bridge-module --- dimos/robot/all_blueprints.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index a4c8c6248d..3912374123 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -167,6 +167,7 @@ "twist-teleop-module": "dimos.teleop.quest.quest_extensions", "unitree-g1-skill-container": "dimos.robot.unitree.g1.skill_container", "unitree-skill-container": "dimos.robot.unitree.unitree_skill_container", + "unity-bridge-module": "dimos.simulation.unity.module", "vlm-agent": "dimos.agents.vlm_agent", "vlm-stream-tester": "dimos.agents.vlm_stream_tester", "voxel-grid-mapper": "dimos.mapping.voxels", From 49c514283ec5c0c9e52fbb3b6397b850ae58b64a Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 21 Mar 2026 16:30:42 -0700 Subject: [PATCH 29/42] merge: pull latest dev From 0fe29f075d9d86f73f6920bb1a75056c424ab530 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 21 Mar 2026 16:44:23 -0700 Subject: [PATCH 30/42] fix: remove @dataclass from UnityBridgeConfig (Pydantic compat), clean up code style - UnityBridgeConfig now inherits properly from ModuleConfig (Pydantic BaseModel) after dev changed ModuleConfig from dataclass to Pydantic - Use pydantic.Field instead of dataclasses.field for default_factory - Remove empty __init__.py (not allowed per test_no_init_files) - Remove section marker comments (not allowed per test_no_sections) - Merge latest dev (LFS assets for piper, xarm6) This broke after merging dev because ModuleConfig changed from @dataclass to Pydantic BaseModel. The @dataclass decorator on UnityBridgeConfig was generating an __init__ that required all fields as positional args, ignoring the Pydantic field defaults. --- dimos/simulation/unity/__init__.py | 0 dimos/simulation/unity/module.py | 23 ++--------------------- dimos/simulation/unity/test_unity_sim.py | 16 ---------------- dimos/utils/ros1.py | 12 ------------ 4 files changed, 2 insertions(+), 49 deletions(-) delete mode 100644 dimos/simulation/unity/__init__.py diff --git a/dimos/simulation/unity/__init__.py b/dimos/simulation/unity/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index d16420495a..324de377da 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -29,7 +29,6 @@ from __future__ import annotations -from dataclasses import dataclass, field import json import math import os @@ -44,6 +43,7 @@ from typing import Any import numpy as np +from pydantic import Field from reactivex.disposable import Disposable from dimos.core.core import rpc @@ -75,9 +75,7 @@ _SUPPORTED_ARCHS = {"x86_64", "AMD64"} -# --------------------------------------------------------------------------- # TCP protocol helpers -# --------------------------------------------------------------------------- def _recvall(sock: socket.socket, size: int) -> bytes: @@ -118,9 +116,7 @@ def _write_tcp_command(sock: socket.socket, command: str, params: dict[str, Any] ) -# --------------------------------------------------------------------------- # Platform validation -# --------------------------------------------------------------------------- def _validate_platform() -> None: @@ -142,12 +138,9 @@ def _validate_platform() -> None: ) -# --------------------------------------------------------------------------- # Config -# --------------------------------------------------------------------------- -@dataclass class UnityBridgeConfig(ModuleConfig): """Configuration for the Unity bridge / vehicle simulator. @@ -172,7 +165,7 @@ class UnityBridgeConfig(ModuleConfig): headless: bool = False # Extra CLI args to pass to the Unity binary. - unity_extra_args: list[str] = field(default_factory=list) + unity_extra_args: list[str] = Field(default_factory=list) # Vehicle parameters vehicle_height: float = 0.75 @@ -187,9 +180,7 @@ class UnityBridgeConfig(ModuleConfig): sim_rate: float = 200.0 -# --------------------------------------------------------------------------- # Module -# --------------------------------------------------------------------------- class UnityBridgeModule(Module[UnityBridgeConfig]): @@ -250,8 +241,6 @@ def rerun_suppress_camera_info(_: Any) -> None: """Suppress CameraInfo logging — the static pinhole handles 3D projection.""" return None - # ---- lifecycle -------------------------------------------------------- - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) self._x = self.config.init_x @@ -332,8 +321,6 @@ def stop(self) -> None: self._unity_process = None super().stop() - # ---- Unity process management ----------------------------------------- - def _resolve_binary(self) -> Path | None: """Find the Unity binary from config or LFS data. @@ -414,8 +401,6 @@ def _launch_unity(self) -> None: f"The binary may still be loading — it will connect when ready." ) - # ---- input callbacks -------------------------------------------------- - def _on_cmd_vel(self, twist: Twist) -> None: with self._cmd_lock: self._fwd_speed = twist.linear.x @@ -433,8 +418,6 @@ def _on_terrain(self, cloud: PointCloud2) -> None: with self._state_lock: self._terrain_z = 0.8 * self._terrain_z + 0.2 * near[:, 2].mean() - # ---- Unity TCP bridge ------------------------------------------------- - def _unity_loop(self) -> None: server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -589,8 +572,6 @@ def _send_to_unity(self, topic: str, data: bytes) -> None: if connected: self._send_queue.put((topic, data)) - # ---- kinematic sim loop ----------------------------------------------- - def _sim_loop(self) -> None: dt = 1.0 / self.config.sim_rate diff --git a/dimos/simulation/unity/test_unity_sim.py b/dimos/simulation/unity/test_unity_sim.py index 31f1237f51..aee84199b9 100644 --- a/dimos/simulation/unity/test_unity_sim.py +++ b/dimos/simulation/unity/test_unity_sim.py @@ -46,9 +46,7 @@ _has_display = bool(os.environ.get("DISPLAY")) -# --------------------------------------------------------------------------- # Helpers -# --------------------------------------------------------------------------- class _MockTransport: @@ -131,9 +129,7 @@ def _recv_tcp(sock) -> tuple[str, bytes]: return d, buf -# --------------------------------------------------------------------------- # Config & Platform — fast, runs everywhere -# --------------------------------------------------------------------------- class TestConfig: @@ -164,9 +160,7 @@ def test_rejects_unsupported_platform(self): _validate_platform() -# --------------------------------------------------------------------------- # Pickle — fast, runs everywhere -# --------------------------------------------------------------------------- class TestPickle: @@ -179,9 +173,7 @@ def test_module_survives_pickle(self): m2.stop() -# --------------------------------------------------------------------------- # ROS1 Deserialization — fast, runs everywhere -# --------------------------------------------------------------------------- class TestROS1Deserialization: @@ -195,9 +187,7 @@ def test_pointcloud2_round_trip(self): assert frame_id == "map" -# --------------------------------------------------------------------------- # TCP Bridge — needs sockets, ~1s, runs everywhere -# --------------------------------------------------------------------------- class TestTCPBridge: @@ -232,9 +222,7 @@ def test_handshake_and_data_flow(self): np.testing.assert_allclose(received_pts, pts, atol=0.01) -# --------------------------------------------------------------------------- # Kinematic Sim — needs threading, ~1s, runs everywhere -# --------------------------------------------------------------------------- class TestKinematicSim: @@ -270,9 +258,7 @@ def test_cmd_vel_moves_robot(self): assert last_odom.x > 0.5 -# --------------------------------------------------------------------------- # Rerun Config — fast, runs everywhere -# --------------------------------------------------------------------------- class TestRerunConfig: @@ -287,10 +273,8 @@ def test_suppress_returns_none(self): assert UnityBridgeModule.rerun_suppress_camera_info(None) is None -# --------------------------------------------------------------------------- # Live Unity — slow, requires Linux x86_64 + DISPLAY # These are skipped in CI and on unsupported platforms. -# --------------------------------------------------------------------------- @pytest.mark.slow diff --git a/dimos/utils/ros1.py b/dimos/utils/ros1.py index b3c6c43456..9053ce20bc 100644 --- a/dimos/utils/ros1.py +++ b/dimos/utils/ros1.py @@ -41,9 +41,7 @@ import numpy as np -# --------------------------------------------------------------------------- # Low-level readers -# --------------------------------------------------------------------------- class ROS1Reader: @@ -104,9 +102,7 @@ def remaining(self) -> int: return len(self.data) - self.off -# --------------------------------------------------------------------------- # Low-level writer -# --------------------------------------------------------------------------- class ROS1Writer: @@ -153,9 +149,7 @@ def bytes(self) -> bytes: return bytes(self.buf) -# --------------------------------------------------------------------------- # Header (std_msgs/Header) -# --------------------------------------------------------------------------- @dataclass @@ -180,9 +174,7 @@ def write_header( w.string(frame_id) -# --------------------------------------------------------------------------- # sensor_msgs/PointCloud2 -# --------------------------------------------------------------------------- @dataclass @@ -266,9 +258,7 @@ def deserialize_pointcloud2(data: bytes) -> tuple[np.ndarray, str, float] | None return None -# --------------------------------------------------------------------------- # sensor_msgs/CompressedImage -# --------------------------------------------------------------------------- def deserialize_compressed_image(data: bytes) -> tuple[bytes, str, str, float] | None: @@ -288,9 +278,7 @@ def deserialize_compressed_image(data: bytes) -> tuple[bytes, str, str, float] | return None -# --------------------------------------------------------------------------- # geometry_msgs/PoseStamped (serialize) -# --------------------------------------------------------------------------- def serialize_pose_stamped( From 7e69093213714bd0d96433d75e5792542f324fd5 Mon Sep 17 00:00:00 2001 From: RD <63036454+ruthwikdasyam@users.noreply.github.com> Date: Sat, 21 Mar 2026 18:07:11 -0700 Subject: [PATCH 31/42] MuJoCo sim support for Manipulation (#1639) --- dimos/control/blueprints/_hardware.py | 62 ++++- dimos/control/blueprints/teleop.py | 32 ++- dimos/control/components.py | 2 + dimos/control/coordinator.py | 1 + dimos/e2e_tests/test_simulation_module.py | 86 ------- dimos/hardware/manipulators/sim/adapter.py | 70 ++++++ dimos/robot/all_blueprints.py | 4 +- dimos/simulation/engines/base.py | 16 ++ dimos/simulation/engines/mujoco_engine.py | 33 +++ .../manipulators/sim_manip_interface.py | 45 +++- dimos/simulation/manipulators/sim_module.py | 234 ------------------ .../manipulators/test_sim_adapter.py | 184 ++++++++++++++ .../manipulators/test_sim_module.py | 124 ---------- dimos/simulation/sim_blueprints.py | 45 ---- dimos/teleop/quest/blueprints.py | 16 ++ 15 files changed, 442 insertions(+), 512 deletions(-) delete mode 100644 dimos/e2e_tests/test_simulation_module.py create mode 100644 dimos/hardware/manipulators/sim/adapter.py delete mode 100644 dimos/simulation/manipulators/sim_module.py create mode 100644 dimos/simulation/manipulators/test_sim_adapter.py delete mode 100644 dimos/simulation/manipulators/test_sim_module.py delete mode 100644 dimos/simulation/sim_blueprints.py diff --git a/dimos/control/blueprints/_hardware.py b/dimos/control/blueprints/_hardware.py index a36027865a..7e72628f0b 100644 --- a/dimos/control/blueprints/_hardware.py +++ b/dimos/control/blueprints/_hardware.py @@ -34,6 +34,11 @@ XARM6_MODEL_PATH = LfsPath("xarm_description/urdf/xarm6/xarm6.urdf") XARM7_MODEL_PATH = LfsPath("xarm_description/urdf/xarm7/xarm7.urdf") +# Simulation model paths (MJCF) +XARM7_SIM_PATH = LfsPath("xarm7/scene.xml") +XARM6_SIM_PATH = LfsPath("xarm6/scene.xml") +PIPER_SIM_PATH = LfsPath("piper/scene.xml") + def mock_arm(hw_id: str = "arm", n_joints: int = 7) -> HardwareComponent: """Mock manipulator (no real hardware).""" @@ -46,7 +51,9 @@ def mock_arm(hw_id: str = "arm", n_joints: int = 7) -> HardwareComponent: def xarm7(hw_id: str = "arm", *, gripper: bool = False) -> HardwareComponent: - """XArm7 real hardware (7-DOF).""" + """XArm7 (7-DOF). Uses sim when --simulation flag is set.""" + if global_config.simulation: + return sim_xarm7(hw_id, headless=False, gripper=gripper) return HardwareComponent( hardware_id=hw_id, hardware_type=HardwareType.MANIPULATOR, @@ -59,7 +66,9 @@ def xarm7(hw_id: str = "arm", *, gripper: bool = False) -> HardwareComponent: def xarm6(hw_id: str = "arm", *, gripper: bool = False) -> HardwareComponent: - """XArm6 real hardware (6-DOF).""" + """XArm6 (6-DOF). Uses sim when --simulation flag is set.""" + if global_config.simulation: + return sim_xarm6(hw_id, headless=False, gripper=gripper) return HardwareComponent( hardware_id=hw_id, hardware_type=HardwareType.MANIPULATOR, @@ -71,8 +80,10 @@ def xarm6(hw_id: str = "arm", *, gripper: bool = False) -> HardwareComponent: ) -def piper(hw_id: str = "arm") -> HardwareComponent: - """Piper arm (6-DOF, CAN bus).""" +def piper(hw_id: str = "arm", *, gripper: bool = False) -> HardwareComponent: + """Piper arm (6-DOF, CAN bus). Uses sim when --simulation flag is set.""" + if global_config.simulation: + return sim_piper(hw_id, headless=False, gripper=gripper) return HardwareComponent( hardware_id=hw_id, hardware_type=HardwareType.MANIPULATOR, @@ -80,6 +91,7 @@ def piper(hw_id: str = "arm") -> HardwareComponent: adapter_type="piper", address=CAN_PORT, auto_enable=True, + gripper_joints=make_gripper_joints(hw_id) if gripper else [], ) @@ -91,3 +103,45 @@ def mock_twist_base(hw_id: str = "base") -> HardwareComponent: joints=make_twist_base_joints(hw_id), adapter_type="mock_twist_base", ) + + +def sim_xarm7( + hw_id: str = "arm", *, headless: bool = True, gripper: bool = False +) -> HardwareComponent: + return HardwareComponent( + hardware_id=hw_id, + hardware_type=HardwareType.MANIPULATOR, + joints=make_joints(hw_id, 7), + adapter_type="sim_mujoco", + address=str(XARM7_SIM_PATH), + adapter_kwargs={"headless": headless}, + gripper_joints=make_gripper_joints(hw_id) if gripper else [], + ) + + +def sim_xarm6( + hw_id: str = "arm", *, headless: bool = True, gripper: bool = False +) -> HardwareComponent: + return HardwareComponent( + hardware_id=hw_id, + hardware_type=HardwareType.MANIPULATOR, + joints=make_joints(hw_id, 6), + adapter_type="sim_mujoco", + address=str(XARM6_SIM_PATH), + adapter_kwargs={"headless": headless}, + gripper_joints=make_gripper_joints(hw_id) if gripper else [], + ) + + +def sim_piper( + hw_id: str = "arm", *, headless: bool = True, gripper: bool = False +) -> HardwareComponent: + return HardwareComponent( + hardware_id=hw_id, + hardware_type=HardwareType.MANIPULATOR, + joints=make_joints(hw_id, 6), + adapter_type="sim_mujoco", + address=str(PIPER_SIM_PATH), + adapter_kwargs={"headless": headless}, + gripper_joints=make_gripper_joints(hw_id) if gripper else [], + ) diff --git a/dimos/control/blueprints/teleop.py b/dimos/control/blueprints/teleop.py index 1dfa55d80d..a6266efc09 100644 --- a/dimos/control/blueprints/teleop.py +++ b/dimos/control/blueprints/teleop.py @@ -43,8 +43,8 @@ from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.quest.quest_types import Buttons -# XArm6 teleop - streaming position control -coordinator_teleop_xarm6 = ControlCoordinator.blueprint( +# XArm6 servo - streaming position control +coordinator_servo_xarm6 = ControlCoordinator.blueprint( hardware=[xarm6()], tasks=[ TaskConfig( @@ -200,6 +200,30 @@ } ) +# Single XArm6 with TeleopIK +coordinator_teleop_xarm6 = ControlCoordinator.blueprint( + hardware=[xarm6()], + tasks=[ + TaskConfig( + name="teleop_xarm", + type="teleop_ik", + joint_names=[f"arm_joint{i + 1}" for i in range(6)], + priority=10, + model_path=XARM6_MODEL_PATH, + ee_joint_id=6, + hand="right", + ), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + ("cartesian_command", PoseStamped): LCMTransport( + "/coordinator/cartesian_command", PoseStamped + ), + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + } +) + # Dual arm teleop: XArm6 + Piper with TeleopIK coordinator_teleop_dual = ControlCoordinator.blueprint( hardware=[xarm6("xarm_arm"), piper("piper_arm")], @@ -235,15 +259,13 @@ __all__ = [ - # Cartesian IK "coordinator_cartesian_ik_mock", "coordinator_cartesian_ik_piper", "coordinator_combined_xarm6", + "coordinator_servo_xarm6", "coordinator_teleop_dual", "coordinator_teleop_piper", - # Servo / Velocity "coordinator_teleop_xarm6", - # TeleopIK "coordinator_teleop_xarm7", "coordinator_velocity_xarm6", ] diff --git a/dimos/control/components.py b/dimos/control/components.py index a32b395b83..42b5d9c9ea 100644 --- a/dimos/control/components.py +++ b/dimos/control/components.py @@ -16,6 +16,7 @@ from dataclasses import dataclass, field from enum import Enum +from typing import Any HardwareId = str JointName = str @@ -57,6 +58,7 @@ class HardwareComponent: address: str | None = None auto_enable: bool = True gripper_joints: list[JointName] = field(default_factory=list) + adapter_kwargs: dict[str, Any] = field(default_factory=dict) @property def all_joints(self) -> list[JointName]: diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index 78184d8272..ef5655036f 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -238,6 +238,7 @@ def _create_adapter(self, component: HardwareComponent) -> ManipulatorAdapter: dof=len(component.joints), address=component.address, hardware_id=component.hardware_id, + **component.adapter_kwargs, ) def _create_twist_base_adapter(self, component: HardwareComponent) -> TwistBaseAdapter: diff --git a/dimos/e2e_tests/test_simulation_module.py b/dimos/e2e_tests/test_simulation_module.py deleted file mode 100644 index e08183fc24..0000000000 --- a/dimos/e2e_tests/test_simulation_module.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""End-to-end tests for the simulation module.""" - -import pytest - -from dimos.msgs.sensor_msgs.JointCommand import JointCommand -from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.msgs.sensor_msgs.RobotState import RobotState - - -def _positions_within_tolerance( - positions: list[float], - target: list[float], - tolerance: float, -) -> bool: - if len(positions) < len(target): - return False - return all(abs(positions[i] - target[i]) <= tolerance for i in range(len(target))) - - -@pytest.mark.skipif_in_ci -@pytest.mark.slow -class TestSimulationModuleE2E: - def test_xarm7_joint_state_published(self, lcm_spy, start_blueprint) -> None: - joint_state_topic = "/xarm/joint_states#sensor_msgs.JointState" - lcm_spy.save_topic(joint_state_topic) - - start_blueprint("xarm7-trajectory-sim") - lcm_spy.wait_for_saved_topic(joint_state_topic, timeout=15.0) - - with lcm_spy._messages_lock: - raw_joint_state = lcm_spy.messages[joint_state_topic][0] - - joint_state = JointState.lcm_decode(raw_joint_state) - assert len(joint_state.name) == 8 - assert len(joint_state.position) == 8 - - def test_xarm7_robot_state_published(self, lcm_spy, start_blueprint) -> None: - robot_state_topic = "/xarm/robot_state#sensor_msgs.RobotState" - lcm_spy.save_topic(robot_state_topic) - - start_blueprint("xarm7-trajectory-sim") - lcm_spy.wait_for_saved_topic(robot_state_topic, timeout=15.0) - - with lcm_spy._messages_lock: - raw_robot_state = lcm_spy.messages[robot_state_topic][0] - - robot_state = RobotState.lcm_decode(raw_robot_state) - assert robot_state.mt_able in (0, 1) - - def test_xarm7_joint_command_updates_joint_state(self, lcm_spy, start_blueprint) -> None: - joint_state_topic = "/xarm/joint_states#sensor_msgs.JointState" - joint_command_topic = "/xarm/joint_position_command#sensor_msgs.JointCommand" - lcm_spy.save_topic(joint_state_topic) - - start_blueprint("xarm7-trajectory-sim") - lcm_spy.wait_for_saved_topic(joint_state_topic, timeout=15.0) - - target_positions = [0.2, -0.2, 0.1, -0.1, 0.15, -0.15, 0.05] - lcm_spy.publish(joint_command_topic, JointCommand(positions=target_positions)) - - tolerance = 0.03 - lcm_spy.wait_for_message_result( - joint_state_topic, - JointState, - predicate=lambda msg: _positions_within_tolerance( - list(msg.position), - target_positions, - tolerance, - ), - fail_message=("joint_state did not reach commanded positions within tolerance"), - timeout=10.0, - ) diff --git a/dimos/hardware/manipulators/sim/adapter.py b/dimos/hardware/manipulators/sim/adapter.py new file mode 100644 index 0000000000..3979ce98c5 --- /dev/null +++ b/dimos/hardware/manipulators/sim/adapter.py @@ -0,0 +1,70 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MuJoCo simulation adapter for ControlCoordinator integration. + +Thin wrapper around SimManipInterface that plugs into the adapter registry. +Arm joint methods are inherited from SimManipInterface. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from dimos.simulation.engines.mujoco_engine import MujocoEngine +from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface + +if TYPE_CHECKING: + from dimos.hardware.manipulators.registry import AdapterRegistry + + +class SimMujocoAdapter(SimManipInterface): + """Uses ``address`` as the MJCF XML path (same field real adapters use for IP/port). + If the engine has more joints than ``dof``, the extra joint at index ``dof`` + is treated as the gripper, with ctrl range scaled automatically. + """ + + def __init__( + self, + dof: int = 7, + address: str | None = None, + headless: bool = True, + **_: Any, + ) -> None: + if address is None: + raise ValueError("address (MJCF XML path) is required for sim_mujoco adapter") + engine = MujocoEngine(config_path=Path(address), headless=headless) + + # Detect gripper from engine joints + gripper_idx = None + gripper_kwargs = {} + joint_names = list(engine.joint_names) + if len(joint_names) > dof: + gripper_idx = dof + ctrl_range = engine.get_actuator_ctrl_range(dof) + joint_range = engine.get_joint_range(dof) + if ctrl_range is None or joint_range is None: + raise ValueError(f"Gripper joint at index {dof} missing ctrl/joint range in MJCF") + gripper_kwargs = {"gripper_ctrl_range": ctrl_range, "gripper_joint_range": joint_range} + + super().__init__(engine=engine, dof=dof, gripper_idx=gripper_idx, **gripper_kwargs) + + +def register(registry: AdapterRegistry) -> None: + """Register this adapter with the registry.""" + registry.register("sim_mujoco", SimMujocoAdapter) + + +__all__ = ["SimMujocoAdapter"] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 1fe034fd29..03a1f47f2a 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -27,6 +27,7 @@ "coordinator-mock-twist-base": "dimos.control.blueprints.mobile:coordinator_mock_twist_base", "coordinator-piper": "dimos.control.blueprints.basic:coordinator_piper", "coordinator-piper-xarm": "dimos.control.blueprints.dual:coordinator_piper_xarm", + "coordinator-servo-xarm6": "dimos.control.blueprints.teleop:coordinator_servo_xarm6", "coordinator-teleop-dual": "dimos.control.blueprints.teleop:coordinator_teleop_dual", "coordinator-teleop-piper": "dimos.control.blueprints.teleop:coordinator_teleop_piper", "coordinator-teleop-xarm6": "dimos.control.blueprints.teleop:coordinator_teleop_xarm6", @@ -61,6 +62,7 @@ "teleop-quest-dual": "dimos.teleop.quest.blueprints:teleop_quest_dual", "teleop-quest-piper": "dimos.teleop.quest.blueprints:teleop_quest_piper", "teleop-quest-rerun": "dimos.teleop.quest.blueprints:teleop_quest_rerun", + "teleop-quest-xarm6": "dimos.teleop.quest.blueprints:teleop_quest_xarm6", "teleop-quest-xarm7": "dimos.teleop.quest.blueprints:teleop_quest_xarm7", "uintree-g1-primitive-no-nav": "dimos.robot.unitree.g1.blueprints.primitive.uintree_g1_primitive_no_nav:uintree_g1_primitive_no_nav", "unitree-g1": "dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1:unitree_g1", @@ -92,7 +94,6 @@ "xarm6-planner-only": "dimos.manipulation.blueprints:xarm6_planner_only", "xarm7-planner-coordinator": "dimos.manipulation.blueprints:xarm7_planner_coordinator", "xarm7-planner-coordinator-agent": "dimos.manipulation.blueprints:xarm7_planner_coordinator_agent", - "xarm7-trajectory-sim": "dimos.simulation.sim_blueprints:xarm7_trajectory_sim", } @@ -159,7 +160,6 @@ "rerun-bridge-module": "dimos.visualization.rerun.bridge", "ros-nav": "dimos.navigation.rosnav", "simple-phone-teleop": "dimos.teleop.phone.phone_extensions", - "simulation-module": "dimos.simulation.manipulators.sim_module", "spatial-memory": "dimos.perception.spatial_perception", "speak-skill": "dimos.agents.skills.speak_skill", "temporal-memory": "dimos.perception.experimental.temporal_memory.temporal_memory", diff --git a/dimos/simulation/engines/base.py b/dimos/simulation/engines/base.py index 58e76ecba6..d4b0735528 100644 --- a/dimos/simulation/engines/base.py +++ b/dimos/simulation/engines/base.py @@ -82,3 +82,19 @@ def write_joint_command(self, command: JointState) -> None: @abstractmethod def hold_current_position(self) -> None: """Hold current joint positions.""" + + @abstractmethod + def set_position_target(self, joint_idx: int, value: float) -> None: + """Set position target for a single joint/actuator by index.""" + + @abstractmethod + def get_position_target(self, joint_idx: int) -> float: + """Get current position target for a single joint/actuator by index.""" + + def get_actuator_ctrl_range(self, actuator_idx: int) -> tuple[float, float] | None: + """Get (min, max) ctrl range for an actuator. None if not available.""" + return None + + def get_joint_range(self, joint_idx: int) -> tuple[float, float] | None: + """Get (min, max) position range for a joint. None if not available.""" + return None diff --git a/dimos/simulation/engines/mujoco_engine.py b/dimos/simulation/engines/mujoco_engine.py index 2d1cdf92ac..df8359746a 100644 --- a/dimos/simulation/engines/mujoco_engine.py +++ b/dimos/simulation/engines/mujoco_engine.py @@ -288,12 +288,45 @@ def _set_effort_targets(self, efforts: list[float]) -> None: for i in range(len(efforts)): self._joint_effort_targets[i] = float(efforts[i]) + def set_position_target(self, index: int, value: float) -> None: + with self._lock: + self._joint_position_targets[index] = float(value) + + def get_position_target(self, index: int) -> float: + with self._lock: + return float(self._joint_position_targets[index]) + def hold_current_position(self) -> None: with self._lock: self._command_mode = "position" for i, mapping in enumerate(self._joint_mappings): self._joint_position_targets[i] = self._current_position(mapping) + def get_actuator_ctrl_range(self, joint_index: int) -> tuple[float, float] | None: + mapping = self._joint_mappings[joint_index] + if mapping.actuator_id is None: + return None + lo = float(self._model.actuator_ctrlrange[mapping.actuator_id, 0]) + hi = float(self._model.actuator_ctrlrange[mapping.actuator_id, 1]) + return (lo, hi) + + def get_joint_range(self, joint_index: int) -> tuple[float, float] | None: + mapping = self._joint_mappings[joint_index] + if mapping.tendon_qpos_adrs: + first_adr = mapping.tendon_qpos_adrs[0] + for jid in range(self._model.njnt): + if self._model.jnt_qposadr[jid] == first_adr: + return ( + float(self._model.jnt_range[jid, 0]), + float(self._model.jnt_range[jid, 1]), + ) + if mapping.joint_id is not None: + return ( + float(self._model.jnt_range[mapping.joint_id, 0]), + float(self._model.jnt_range[mapping.joint_id, 1]), + ) + return None + __all__ = [ "MujocoEngine", diff --git a/dimos/simulation/manipulators/sim_manip_interface.py b/dimos/simulation/manipulators/sim_manip_interface.py index 6de570ae15..07e56c5afd 100644 --- a/dimos/simulation/manipulators/sim_manip_interface.py +++ b/dimos/simulation/manipulators/sim_manip_interface.py @@ -30,16 +30,26 @@ class SimManipInterface: """Adapter wrapper around a simulation engine to provide a uniform manipulator API.""" - def __init__(self, engine: SimulationEngine) -> None: + def __init__( + self, + engine: SimulationEngine, + dof: int | None = None, + gripper_idx: int | None = None, + gripper_ctrl_range: tuple[float, float] = (0.0, 1.0), + gripper_joint_range: tuple[float, float] = (0.0, 1.0), + ) -> None: self.logger = logging.getLogger(self.__class__.__name__) self._engine = engine self._joint_names = list(engine.joint_names) - self._dof = len(self._joint_names) + self._dof = dof if dof is not None else len(self._joint_names) self._connected = False self._servos_enabled = False self._control_mode = ControlMode.POSITION self._error_code = 0 self._error_message = "" + self._gripper_idx = gripper_idx + self._gripper_ctrl_range = gripper_ctrl_range + self._gripper_joint_range = gripper_joint_range def connect(self) -> bool: """Connect to the simulation engine.""" @@ -51,8 +61,6 @@ def connect(self) -> bool: if self._engine.connected: self._connected = True self._servos_enabled = True - self._joint_names = list(self._engine.joint_names) - self._dof = len(self._joint_names) self.logger.info( "Successfully connected to simulation", extra={"dof": self._dof}, @@ -64,14 +72,14 @@ def connect(self) -> bool: self.logger.error(f"Sim connection failed: {exc}") return False - def disconnect(self) -> bool: + def disconnect(self) -> None: """Disconnect from simulation.""" try: - return self._engine.disconnect() + self._engine.disconnect() except Exception as exc: - self._connected = False self.logger.error(f"Sim disconnection failed: {exc}") - return False + finally: + self._connected = False def is_connected(self) -> bool: return bool(self._connected and self._engine.connected) @@ -135,7 +143,7 @@ def read_state(self) -> dict[str, int]: def read_error(self) -> tuple[int, str]: return self._error_code, self._error_message - def write_joint_positions(self, positions: list[float]) -> bool: + def write_joint_positions(self, positions: list[float], velocity: float = 1.0) -> bool: if not self._servos_enabled: return False self._control_mode = ControlMode.POSITION @@ -185,11 +193,24 @@ def write_cartesian_position( return False def read_gripper_position(self) -> float | None: - return None + if self._gripper_idx is None: + return None + positions = self._engine.read_joint_positions() + return positions[self._gripper_idx] def write_gripper_position(self, position: float) -> bool: - _ = position - return False + if self._gripper_idx is None: + return False + jlo, jhi = self._gripper_joint_range + clo, chi = self._gripper_ctrl_range + position = max(jlo, min(jhi, position)) + if jhi != jlo: + t = (position - jlo) / (jhi - jlo) + ctrl_value = chi - t * (chi - clo) + else: + ctrl_value = clo + self._engine.set_position_target(self._gripper_idx, ctrl_value) + return True def read_force_torque(self) -> list[float] | None: return None diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py deleted file mode 100644 index 66a2b5d888..0000000000 --- a/dimos/simulation/manipulators/sim_module.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Simulator-agnostic manipulator simulation module.""" - -from collections.abc import Callable -from pathlib import Path -import threading -import time -from typing import Any - -from reactivex.disposable import Disposable - -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs.JointCommand import JointCommand -from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.msgs.sensor_msgs.RobotState import RobotState -from dimos.simulation.engines.registry import EngineType, get_engine -from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface - - -class SimulationModuleConfig(ModuleConfig): - engine: EngineType - config_path: Path | Callable[[], Path] - headless: bool = False - - -class SimulationModule(Module[SimulationModuleConfig]): - """Module wrapper for manipulator simulation across engines.""" - - default_config = SimulationModuleConfig - - joint_state: Out[JointState] - robot_state: Out[RobotState] - joint_position_command: In[JointCommand] - joint_velocity_command: In[JointCommand] - - MIN_CONTROL_RATE = 1.0 - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._backend: SimManipInterface | None = None - self._control_rate = 100.0 - self._monitor_rate = 100.0 - self._joint_prefix = "joint" - self._stop_event = threading.Event() - self._control_thread: threading.Thread | None = None - self._monitor_thread: threading.Thread | None = None - self._command_lock = threading.Lock() - self._pending_positions: list[float] | None = None - self._pending_velocities: list[float] | None = None - - def _create_backend(self) -> SimManipInterface: - engine_cls = get_engine(self.config.engine) - config_path = ( - self.config.config_path() - if callable(self.config.config_path) - else self.config.config_path - ) - engine = engine_cls( - config_path=config_path, - headless=self.config.headless, - ) - return SimManipInterface(engine=engine) - - @rpc - def start(self) -> None: - super().start() - if self._backend is None: - self._backend = self._create_backend() - if not self._backend.connect(): - raise RuntimeError("Failed to connect to simulation backend") - self._backend.write_enable(True) - - self._disposables.add( - Disposable(self.joint_position_command.subscribe(self._on_joint_position_command)) - ) - self._disposables.add( - Disposable(self.joint_velocity_command.subscribe(self._on_joint_velocity_command)) - ) - - self._stop_event.clear() - self._control_thread = threading.Thread( - target=self._control_loop, - daemon=True, - name=f"{self.__class__.__name__}-control", - ) - self._monitor_thread = threading.Thread( - target=self._monitor_loop, - daemon=True, - name=f"{self.__class__.__name__}-monitor", - ) - self._control_thread.start() - self._monitor_thread.start() - - @rpc - def stop(self) -> None: - self._stop_event.set() - if self._control_thread and self._control_thread.is_alive(): - self._control_thread.join(timeout=2.0) - if self._monitor_thread and self._monitor_thread.is_alive(): - self._monitor_thread.join(timeout=2.0) - if self._backend: - self._backend.disconnect() - super().stop() - - @rpc - def enable_servos(self) -> bool: - if not self._backend: - return False - return self._backend.write_enable(True) - - @rpc - def disable_servos(self) -> bool: - if not self._backend: - return False - return self._backend.write_enable(False) - - @rpc - def clear_errors(self) -> bool: - if not self._backend: - return False - return self._backend.write_clear_errors() - - @rpc - def emergency_stop(self) -> bool: - if not self._backend: - return False - return self._backend.write_stop() - - def _on_joint_position_command(self, msg: JointCommand) -> None: - with self._command_lock: - self._pending_positions = list(msg.positions) - self._pending_velocities = None - - def _on_joint_velocity_command(self, msg: JointCommand) -> None: - with self._command_lock: - self._pending_velocities = list(msg.positions) - self._pending_positions = None - - def _control_loop(self) -> None: - period = 1.0 / max(self._control_rate, self.MIN_CONTROL_RATE) - next_tick = time.monotonic() # monotonic time used to avoid time drift - while not self._stop_event.is_set(): - with self._command_lock: - positions = ( - None if self._pending_positions is None else list(self._pending_positions) - ) - velocities = ( - None if self._pending_velocities is None else list(self._pending_velocities) - ) - - if self._backend: - if positions is not None: - self._backend.write_joint_positions(positions) - elif velocities is not None: - self._backend.write_joint_velocities(velocities) - dof = self._backend.get_dof() - names = self._resolve_joint_names(dof) - positions = self._backend.read_joint_positions() - velocities = self._backend.read_joint_velocities() - efforts = self._backend.read_joint_efforts() - self.joint_state.publish( - JointState( - frame_id=self.frame_id, - name=names, - position=positions, - velocity=velocities, - effort=efforts, - ) - ) - next_tick += period - sleep_for = next_tick - time.monotonic() - if sleep_for > 0: - if self._stop_event.wait(sleep_for): - break - else: - next_tick = time.monotonic() - - def _monitor_loop(self) -> None: - period = 1.0 / max(self._monitor_rate, self.MIN_CONTROL_RATE) - next_tick = time.monotonic() # monotonic time used to avoid time drift - while not self._stop_event.is_set(): - if not self._backend: - pass - else: - dof = self._backend.get_dof() - self._resolve_joint_names(dof) - positions = self._backend.read_joint_positions() - self._backend.read_joint_velocities() - self._backend.read_joint_efforts() - state = self._backend.read_state() - error_code, _ = self._backend.read_error() - self.robot_state.publish( - RobotState( - state=state.get("state", 0), - mode=state.get("mode", 0), - error_code=error_code, - warn_code=0, - cmdnum=0, - mt_brake=0, - mt_able=1 if self._backend.read_enabled() else 0, - tcp_pose=[], - tcp_offset=[], - joints=[float(p) for p in positions], - ) - ) - next_tick += period - sleep_for = next_tick - time.monotonic() - if sleep_for > 0: - if self._stop_event.wait(sleep_for): - break - else: - next_tick = time.monotonic() - - def _resolve_joint_names(self, dof: int) -> list[str]: - if self._backend: - names = self._backend.get_joint_names() - if len(names) >= dof: - return list(names[:dof]) - return [f"{self._joint_prefix}{i + 1}" for i in range(dof)] diff --git a/dimos/simulation/manipulators/test_sim_adapter.py b/dimos/simulation/manipulators/test_sim_adapter.py new file mode 100644 index 0000000000..8f253229f0 --- /dev/null +++ b/dimos/simulation/manipulators/test_sim_adapter.py @@ -0,0 +1,184 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SimMujocoAdapter and gripper integration.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from dimos.hardware.manipulators.sim.adapter import SimMujocoAdapter, register +from dimos.simulation.utils.xml_parser import JointMapping + +ARM_DOF = 7 + + +def _make_joint_mapping(name: str, idx: int) -> JointMapping: + """Create a JointMapping for a simple revolute joint.""" + return JointMapping( + name=name, + joint_id=idx, + actuator_id=idx, + qpos_adr=idx, + dof_adr=idx, + tendon_qpos_adrs=(), + tendon_dof_adrs=(), + ) + + +def _make_gripper_mapping(name: str, idx: int) -> JointMapping: + """Create a JointMapping for a tendon-driven gripper.""" + return JointMapping( + name=name, + joint_id=None, + actuator_id=idx, + qpos_adr=None, + dof_adr=None, + tendon_qpos_adrs=(idx, idx + 1), + tendon_dof_adrs=(idx, idx + 1), + ) + + +def _patch_mujoco_engine(n_joints: int): + """Patch only the MuJoCo C-library and filesystem boundaries. + + Mocks ``_resolve_xml_path``, ``MjModel.from_xml_path``, ``MjData``, and + ``build_joint_mappings`` — the rest of ``MujocoEngine.__init__`` runs as-is. + """ + mappings = [_make_joint_mapping(f"joint{i}", i) for i in range(ARM_DOF)] + if n_joints > ARM_DOF: + mappings.append(_make_gripper_mapping(f"joint{ARM_DOF}", ARM_DOF)) + + fake_model = MagicMock() + fake_model.opt.timestep = 0.002 + fake_model.nu = n_joints + fake_model.nq = n_joints + fake_model.njnt = n_joints + fake_model.actuator_ctrlrange = np.array( + [[-6.28, 6.28]] * ARM_DOF + ([[0.0, 255.0]] if n_joints > ARM_DOF else []) + ) + fake_model.jnt_range = np.array( + [[-6.28, 6.28]] * ARM_DOF + ([[0.0, 0.85]] if n_joints > ARM_DOF else []) + ) + fake_model.jnt_qposadr = np.arange(n_joints) + + fake_data = MagicMock() + fake_data.qpos = np.zeros(n_joints + 4) # extra for tendon qpos addresses + fake_data.actuator_length = np.zeros(n_joints) + + patches = [ + patch( + "dimos.simulation.engines.mujoco_engine.MujocoEngine._resolve_xml_path", + return_value=Path("/fake/scene.xml"), + ), + patch( + "dimos.simulation.engines.mujoco_engine.mujoco.MjModel.from_xml_path", + return_value=fake_model, + ), + patch("dimos.simulation.engines.mujoco_engine.mujoco.MjData", return_value=fake_data), + patch("dimos.simulation.engines.mujoco_engine.build_joint_mappings", return_value=mappings), + ] + return patches + + +class TestSimMujocoAdapter: + """Tests for SimMujocoAdapter with and without gripper.""" + + @pytest.fixture + def adapter_with_gripper(self): + """SimMujocoAdapter with ARM_DOF arm joints + 1 gripper joint.""" + patches = _patch_mujoco_engine(ARM_DOF + 1) + for p in patches: + p.start() + try: + adapter = SimMujocoAdapter(dof=ARM_DOF, address="/fake/scene.xml", headless=True) + finally: + for p in patches: + p.stop() + return adapter + + @pytest.fixture + def adapter_no_gripper(self): + """SimMujocoAdapter with ARM_DOF arm joints, no gripper.""" + patches = _patch_mujoco_engine(ARM_DOF) + for p in patches: + p.start() + try: + adapter = SimMujocoAdapter(dof=ARM_DOF, address="/fake/scene.xml", headless=True) + finally: + for p in patches: + p.stop() + return adapter + + def test_address_required(self): + patches = _patch_mujoco_engine(ARM_DOF) + for p in patches: + p.start() + try: + with pytest.raises(ValueError, match="address"): + SimMujocoAdapter(dof=ARM_DOF, address=None) + finally: + for p in patches: + p.stop() + + def test_gripper_detected(self, adapter_with_gripper): + assert adapter_with_gripper._gripper_idx == ARM_DOF + + def test_no_gripper_when_dof_matches(self, adapter_no_gripper): + assert adapter_no_gripper._gripper_idx is None + + def test_read_gripper_position(self, adapter_with_gripper): + pos = adapter_with_gripper.read_gripper_position() + assert pos is not None + + def test_write_gripper_sets_target(self, adapter_with_gripper): + """Write a gripper position and verify the control target was set.""" + assert adapter_with_gripper.write_gripper_position(0.42) is True + target = adapter_with_gripper._engine._joint_position_targets[ARM_DOF] + assert target != 0.0, "write_gripper_position should update the control target" + + def test_read_gripper_position_no_gripper(self, adapter_no_gripper): + assert adapter_no_gripper.read_gripper_position() is None + + def test_write_gripper_position_no_gripper(self, adapter_no_gripper): + assert adapter_no_gripper.write_gripper_position(0.5) is False + + def test_write_gripper_does_not_clobber_arm(self, adapter_with_gripper): + """Gripper write must not overwrite arm joint targets.""" + engine = adapter_with_gripper._engine + for i in range(ARM_DOF): + engine._joint_position_targets[i] = float(i) + 1.0 + + adapter_with_gripper.write_gripper_position(0.0) + + for i in range(ARM_DOF): + assert engine._joint_position_targets[i] == pytest.approx(float(i) + 1.0) + + def test_read_joint_positions_excludes_gripper(self, adapter_with_gripper): + positions = adapter_with_gripper.read_joint_positions() + assert len(positions) == ARM_DOF + + def test_connect_and_disconnect(self, adapter_with_gripper): + with patch("dimos.simulation.engines.mujoco_engine.mujoco.mj_step"): + assert adapter_with_gripper.connect() is True + adapter_with_gripper.disconnect() + + def test_register(self): + registry = MagicMock() + register(registry) + registry.register.assert_called_once_with("sim_mujoco", SimMujocoAdapter) diff --git a/dimos/simulation/manipulators/test_sim_module.py b/dimos/simulation/manipulators/test_sim_module.py deleted file mode 100644 index 951d4790e3..0000000000 --- a/dimos/simulation/manipulators/test_sim_module.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pathlib import Path -import threading - -import pytest - -from dimos.protocol.rpc.spec import RPCSpec -from dimos.simulation.manipulators.sim_module import SimulationModule - - -class _DummyRPC(RPCSpec): - def serve_module_rpc(self, _module) -> None: # type: ignore[no-untyped-def] - return None - - def start(self) -> None: - return None - - def stop(self) -> None: - return None - - -class _FakeBackend: - def __init__(self) -> None: - self._names = ["joint1", "joint2", "joint3"] - - def get_dof(self) -> int: - return len(self._names) - - def get_joint_names(self) -> list[str]: - return list(self._names) - - def read_joint_positions(self) -> list[float]: - return [0.1, 0.2, 0.3] - - def read_joint_velocities(self) -> list[float]: - return [0.0, 0.0, 0.0] - - def read_joint_efforts(self) -> list[float]: - return [0.0, 0.0, 0.0] - - def read_state(self) -> dict[str, int]: - return {"state": 1, "mode": 2} - - def read_error(self) -> tuple[int, str]: - return 0, "" - - def read_enabled(self) -> bool: - return True - - def disconnect(self) -> None: - return None - - -def _run_single_monitor_iteration(module: SimulationModule, monkeypatch) -> None: # type: ignore[no-untyped-def] - def _wait_once(_: float) -> bool: - module._stop_event.set() - raise StopIteration - - monkeypatch.setattr(module._stop_event, "wait", _wait_once) - with pytest.raises(StopIteration): - module._monitor_loop() - - -def _run_single_control_iteration(module: SimulationModule, monkeypatch) -> None: # type: ignore[no-untyped-def] - def _wait_once(_: float) -> bool: - module._stop_event.set() - raise StopIteration - - monkeypatch.setattr(module._stop_event, "wait", _wait_once) - with pytest.raises(StopIteration): - module._control_loop() - - -def test_simulation_module_publishes_joint_state(monkeypatch) -> None: - module = SimulationModule( - engine="mujoco", - config_path=Path("."), - rpc_transport=_DummyRPC, - ) - module._backend = _FakeBackend() # type: ignore[assignment] - module._stop_event = threading.Event() - - joint_states: list[object] = [] - module.joint_state.subscribe(joint_states.append) - try: - _run_single_control_iteration(module, monkeypatch) - finally: - module.stop() - - assert len(joint_states) == 1 - assert joint_states[0].name == ["joint1", "joint2", "joint3"] - - -def test_simulation_module_publishes_robot_state(monkeypatch) -> None: - module = SimulationModule( - engine="mujoco", - config_path=Path("."), - rpc_transport=_DummyRPC, - ) - module._backend = _FakeBackend() # type: ignore[assignment] - module._stop_event = threading.Event() - - robot_states: list[object] = [] - module.robot_state.subscribe(robot_states.append) - try: - _run_single_monitor_iteration(module, monkeypatch) - finally: - module.stop() - - assert len(robot_states) == 1 - assert robot_states[0].state == 1 diff --git a/dimos/simulation/sim_blueprints.py b/dimos/simulation/sim_blueprints.py deleted file mode 100644 index 2a8dd2d029..0000000000 --- a/dimos/simulation/sim_blueprints.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs.JointCommand import JointCommand -from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.msgs.sensor_msgs.RobotState import RobotState -from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory -from dimos.simulation.manipulators.sim_module import SimulationModule -from dimos.utils.data import LfsPath - -xarm7_trajectory_sim = SimulationModule.blueprint( - engine="mujoco", - config_path=LfsPath("xarm7/scene.xml"), - headless=True, -).transports( - { - ("joint_state", JointState): LCMTransport("/xarm/joint_states", JointState), - ("robot_state", RobotState): LCMTransport("/xarm/robot_state", RobotState), - ("joint_position_command", JointCommand): LCMTransport( - "/xarm/joint_position_command", JointCommand - ), - ("trajectory", JointTrajectory): LCMTransport("/trajectory", JointTrajectory), - } -) - - -__all__ = [ - "xarm7_trajectory_sim", -] - -if __name__ == "__main__": - xarm7_trajectory_sim.build().loop() diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index 6855ab62ca..d6367310de 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -18,6 +18,7 @@ from dimos.control.blueprints.teleop import ( coordinator_teleop_dual, coordinator_teleop_piper, + coordinator_teleop_xarm6, coordinator_teleop_xarm7, ) from dimos.core.blueprints import autoconnect @@ -68,6 +69,20 @@ ) +# Single XArm6 teleop: right controller -> xarm6 +teleop_quest_xarm6 = autoconnect( + ArmTeleopModule.blueprint(task_names={"right": "teleop_xarm"}), + coordinator_teleop_xarm6, +).transports( + { + ("right_controller_output", PoseStamped): LCMTransport( + "/coordinator/cartesian_command", PoseStamped + ), + ("buttons", Buttons): LCMTransport("/teleop/buttons", Buttons), + } +) + + # Dual arm teleop: right -> piper, left -> xarm6 (TeleopIK) teleop_quest_dual = autoconnect( ArmTeleopModule.blueprint(task_names={"right": "teleop_piper", "left": "teleop_xarm"}), @@ -89,5 +104,6 @@ "teleop_quest_dual", "teleop_quest_piper", "teleop_quest_rerun", + "teleop_quest_xarm6", "teleop_quest_xarm7", ] From 645cb0d8c8540e3a646623f2cf651e6ce9433929 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 21 Mar 2026 19:27:35 -0700 Subject: [PATCH 32/42] fix: address greptile review comments - native_module: buffer last 50 stderr lines in reader thread so crash report actually contains stderr output (was always empty before because the reader thread consumed and closed the stream) - unity module: wrap server socket in try/finally to prevent FD leak if bind()/listen() raises - unity module: drain stale send-queue messages at start of each new Unity connection to prevent delivering old-session data - unity module: read self._x/self._y under _state_lock in _on_terrain callback to ensure atomic read of the position pair --- dimos/core/native_module.py | 17 +++++------ dimos/simulation/unity/module.py | 50 +++++++++++++++++++------------- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 570c224d43..74471f34d5 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -40,6 +40,7 @@ class MyCppModule(NativeModule): from __future__ import annotations +import collections import enum import inspect import json @@ -131,9 +132,11 @@ class NativeModule(Module[_NativeConfig]): _process: subprocess.Popen[bytes] | None = None _watchdog: threading.Thread | None = None _stopping: bool = False + _last_stderr_lines: collections.deque[str] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) + self._last_stderr_lines = collections.deque(maxlen=50) self._resolve_paths() @rpc @@ -216,15 +219,8 @@ def _watch_process(self) -> None: module_name = type(self).__name__ exe_name = Path(self.config.executable).name if self.config.executable else "unknown" - # Collect any remaining stderr for the crash report - last_stderr = "" - if self._process.stderr and not self._process.stderr.closed: - try: - remaining = self._process.stderr.read() - if remaining: - last_stderr = remaining.decode("utf-8", errors="replace").strip() - except Exception: - pass + # Use buffered stderr lines from the reader thread for the crash report. + last_stderr = "\n".join(self._last_stderr_lines) logger.error( f"Native process crashed: {module_name} ({exe_name})", @@ -246,10 +242,13 @@ def _read_log_stream(self, stream: IO[bytes] | None, level: str) -> None: if stream is None: return log_fn = getattr(logger, level) + is_stderr = level == "warning" for raw in stream: line = raw.decode("utf-8", errors="replace").rstrip() if not line: continue + if is_stderr: + self._last_stderr_lines.append(line) if self.config.log_format == LogFormat.JSON: try: data = json.loads(line) diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index 324de377da..70a64e4c3e 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -411,8 +411,10 @@ def _on_terrain(self, cloud: PointCloud2) -> None: points, _ = cloud.as_numpy() if len(points) == 0: return - dx = points[:, 0] - self._x - dy = points[:, 1] - self._y + with self._state_lock: + cur_x, cur_y = self._x, self._y + dx = points[:, 0] - cur_x + dy = points[:, 1] - cur_y near = points[np.sqrt(dx * dx + dy * dy) < 0.5] if len(near) >= 10: with self._state_lock: @@ -426,28 +428,36 @@ def _unity_loop(self) -> None: server_sock.settimeout(2.0) logger.info(f"TCP server on :{self.config.unity_port}") - while self._running: - try: - conn, addr = server_sock.accept() - logger.info(f"Unity connected from {addr}") + try: + while self._running: try: - self._bridge_connection(conn) + conn, addr = server_sock.accept() + logger.info(f"Unity connected from {addr}") + try: + self._bridge_connection(conn) + except Exception as e: + logger.info(f"Unity connection ended: {e}") + finally: + with self._state_lock: + self._unity_connected = False + conn.close() + except TimeoutError: + continue except Exception as e: - logger.info(f"Unity connection ended: {e}") - finally: - with self._state_lock: - self._unity_connected = False - conn.close() - except TimeoutError: - continue - except Exception as e: - if self._running: - logger.warning(f"TCP server error: {e}") - time.sleep(1.0) - - server_sock.close() + if self._running: + logger.warning(f"TCP server error: {e}") + time.sleep(1.0) + finally: + server_sock.close() def _bridge_connection(self, sock: socket.socket) -> None: + # Drain stale messages from a previous session. + while not self._send_queue.empty(): + try: + self._send_queue.get_nowait() + except Empty: + break + sock.settimeout(None) with self._state_lock: self._unity_connected = True From 5d461c6359c036effc8bdd631f1ee62e82beefaa Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 21 Mar 2026 19:40:30 -0700 Subject: [PATCH 33/42] fix: address all paul-review issues on unity simulator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Must-fix: - _running changed from bare bool to threading.Event for proper cross-thread visibility (consistent with _unity_ready and halt) - _sim_loop now holds _state_lock during position reads/writes so _on_terrain sees atomic position pairs - _unity_sender logs exceptions before halting instead of swallowing - Camera intrinsics unified: _publish_camera_info and rerun_static_pinhole now share the same constants (_CAM_FX/FY/CX/CY from 120° HFOV) Should-fix: - import signal and import cv2 moved to top-level (no inline imports) - stop() catches subprocess.TimeoutExpired specifically, logs SIGKILL escalation - Queue drain loop simplified: while True + Empty break (no .empty()) - server_sock wrapped in try/finally from creation (not just the loop) - unity_host default changed from 0.0.0.0 to 127.0.0.1 (loopback) - Bridge connection uses 30s read timeout to detect hung Unity sessions Nits: - Tests use try/finally for thread cleanup - Pickle test checks _running.is_set() not bare bool - Added edge case tests: empty/truncated/garbage pointcloud, truncated compressed image, pose_stamped round-trip - ROS1 deserializers log at DEBUG level on parse failure instead of silently returning None --- dimos/models/segmentation/edge_tam.py | 1 - dimos/models/vl/create.py | 2 +- dimos/simulation/unity/module.py | 158 ++++++++++++----------- dimos/simulation/unity/test_unity_sim.py | 87 +++++++++---- dimos/utils/ros1.py | 5 + 5 files changed, 152 insertions(+), 101 deletions(-) diff --git a/dimos/models/segmentation/edge_tam.py b/dimos/models/segmentation/edge_tam.py index 61b06d5efd..91cdec661d 100644 --- a/dimos/models/segmentation/edge_tam.py +++ b/dimos/models/segmentation/edge_tam.py @@ -14,7 +14,6 @@ from collections.abc import Generator from contextlib import contextmanager -import os from pathlib import Path import shutil import tempfile diff --git a/dimos/models/vl/create.py b/dimos/models/vl/create.py index 7fe5a0dcb2..b39159c54f 100644 --- a/dimos/models/vl/create.py +++ b/dimos/models/vl/create.py @@ -1,7 +1,7 @@ from typing import Any -from dimos.models.vl.types import VlModelName from dimos.models.vl.base import VlModel +from dimos.models.vl.types import VlModelName def create(name: VlModelName) -> VlModel[Any]: diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index 70a64e4c3e..1f02395ba4 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -35,6 +35,7 @@ from pathlib import Path import platform from queue import Empty, Queue +import signal import socket import struct import subprocess @@ -42,6 +43,7 @@ import time from typing import Any +import cv2 import numpy as np from pydantic import Field from reactivex.disposable import Disposable @@ -74,6 +76,11 @@ _SUPPORTED_SYSTEMS = {"Linux"} _SUPPORTED_ARCHS = {"x86_64", "AMD64"} +# Read timeout for the Unity TCP connection (seconds). If Unity stops +# sending data for longer than this the bridge treats it as a hung +# connection and drops it. +_BRIDGE_READ_TIMEOUT = 30.0 + # TCP protocol helpers @@ -157,7 +164,9 @@ class UnityBridgeConfig(ModuleConfig): unity_connect_timeout: float = 30.0 # TCP server settings (we listen; Unity connects to us). - unity_host: str = "0.0.0.0" + # Default to loopback — set to "0.0.0.0" explicitly if Unity runs + # on a different machine. + unity_host: str = "127.0.0.1" unity_port: int = 10000 # Run Unity with no visible window (set -batchmode -nographics). @@ -180,6 +189,22 @@ class UnityBridgeConfig(ModuleConfig): sim_rate: float = 200.0 +# Camera intrinsics constants. +# +# The Unity camera produces a 360° cylindrical panorama (1920×640). +# A true pinhole model cannot represent this, so we approximate with +# a 120° horizontal FOV window. Both CameraInfo and the Rerun static +# pinhole use the SAME focal length so downstream consumers see +# consistent intrinsics. +_CAM_WIDTH = 1920 +_CAM_HEIGHT = 640 +_CAM_HFOV_RAD = math.radians(120.0) +_CAM_FX = (_CAM_WIDTH / 2.0) / math.tan(_CAM_HFOV_RAD / 2.0) +_CAM_FY = _CAM_FX +_CAM_CX = _CAM_WIDTH / 2.0 +_CAM_CY = _CAM_HEIGHT / 2.0 + + # Module @@ -206,27 +231,14 @@ class UnityBridgeModule(Module[UnityBridgeConfig]): semantic_image: Out[Image] camera_info: Out[CameraInfo] - # Rerun static config for 3D camera projection — use this when building - # your rerun_config so the panoramic image renders correctly in 3D. - # - # Usage: - # rerun_config = { - # "static": {"world/color_image": UnityBridgeModule.rerun_static_pinhole}, - # "visual_override": {"world/camera_info": UnityBridgeModule.rerun_suppress_camera_info}, - # } @staticmethod def rerun_static_pinhole(rr: Any) -> list[Any]: """Static Pinhole + Transform3D for the Unity panoramic camera.""" - width, height = 1920, 640 - hfov_rad = math.radians(120.0) - fx = (width / 2.0) / math.tan(hfov_rad / 2.0) - fy = fx - cx, cy = width / 2.0, height / 2.0 return [ rr.Pinhole( - resolution=[width, height], - focal_length=[fx, fy], - principal_point=[cx, cy], + resolution=[_CAM_WIDTH, _CAM_HEIGHT], + focal_length=[_CAM_FX, _CAM_FY], + principal_point=[_CAM_CX, _CAM_CY], camera_xyz=rr.ViewCoordinates.RDF, ), rr.Transform3D( @@ -255,7 +267,7 @@ def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] self._yaw_rate = 0.0 self._cmd_lock = threading.Lock() self._state_lock = threading.Lock() - self._running = False + self._running = threading.Event() self._sim_thread: threading.Thread | None = None self._unity_thread: threading.Thread | None = None self._unity_connected = False @@ -274,6 +286,7 @@ def __getstate__(self) -> dict[str, Any]: # type: ignore[override] "_unity_process", "_send_queue", "_unity_ready", + "_running", ): state.pop(key, None) return state @@ -287,7 +300,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: self._unity_process = None self._send_queue = Queue() self._unity_ready = threading.Event() - self._running = False + self._running = threading.Event() self._binary_path = self._resolve_binary() @rpc @@ -295,7 +308,7 @@ def start(self) -> None: super().start() self._disposables.add(Disposable(self.cmd_vel.subscribe(self._on_cmd_vel))) self._disposables.add(Disposable(self.terrain_map.subscribe(self._on_terrain))) - self._running = True + self._running.set() self._sim_thread = threading.Thread(target=self._sim_loop, daemon=True) self._sim_thread.start() self._unity_thread = threading.Thread(target=self._unity_loop, daemon=True) @@ -304,19 +317,20 @@ def start(self) -> None: @rpc def stop(self) -> None: - self._running = False + self._running.clear() if self._sim_thread: self._sim_thread.join(timeout=2.0) if self._unity_thread: self._unity_thread.join(timeout=2.0) if self._unity_process is not None and self._unity_process.poll() is None: - import signal as _sig - logger.info(f"Stopping Unity (pid={self._unity_process.pid})") - self._unity_process.send_signal(_sig.SIGTERM) + self._unity_process.send_signal(signal.SIGTERM) try: self._unity_process.wait(timeout=5) - except Exception: + except subprocess.TimeoutExpired: + logger.warning( + f"Unity pid={self._unity_process.pid} did not exit after SIGTERM, killing" + ) self._unity_process.kill() self._unity_process = None super().stop() @@ -422,21 +436,21 @@ def _on_terrain(self, cloud: PointCloud2) -> None: def _unity_loop(self) -> None: server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_sock.bind((self.config.unity_host, self.config.unity_port)) - server_sock.listen(1) - server_sock.settimeout(2.0) - logger.info(f"TCP server on :{self.config.unity_port}") - try: - while self._running: + server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_sock.bind((self.config.unity_host, self.config.unity_port)) + server_sock.listen(1) + server_sock.settimeout(2.0) + logger.info(f"TCP server on :{self.config.unity_port}") + + while self._running.is_set(): try: conn, addr = server_sock.accept() logger.info(f"Unity connected from {addr}") try: self._bridge_connection(conn) except Exception as e: - logger.info(f"Unity connection ended: {e}") + logger.warning(f"Unity connection ended: {e}") finally: with self._state_lock: self._unity_connected = False @@ -444,7 +458,7 @@ def _unity_loop(self) -> None: except TimeoutError: continue except Exception as e: - if self._running: + if self._running.is_set(): logger.warning(f"TCP server error: {e}") time.sleep(1.0) finally: @@ -452,13 +466,13 @@ def _unity_loop(self) -> None: def _bridge_connection(self, sock: socket.socket) -> None: # Drain stale messages from a previous session. - while not self._send_queue.empty(): + while True: try: self._send_queue.get_nowait() except Empty: break - sock.settimeout(None) + sock.settimeout(_BRIDGE_READ_TIMEOUT) with self._state_lock: self._unity_connected = True self._unity_ready.set() @@ -477,8 +491,11 @@ def _bridge_connection(self, sock: socket.socket) -> None: sender.start() try: - while self._running and not halt.is_set(): - dest, data = _read_tcp_message(sock) + while self._running.is_set() and not halt.is_set(): + try: + dest, data = _read_tcp_message(sock) + except TimeoutError: + continue if dest == "": continue elif dest.startswith("__"): @@ -501,7 +518,8 @@ def _unity_sender(self, sock: socket.socket, halt: threading.Event) -> None: _write_tcp_message(sock, dest, data) except Empty: continue - except Exception: + except Exception as e: + logger.warning(f"Unity sender error: {e}") halt.set() def _handle_syscommand(self, dest: str, data: bytes) -> None: @@ -540,8 +558,6 @@ def _handle_unity_message(self, topic: str, data: bytes) -> None: if img_result is not None: img_bytes, _fmt, _frame_id, ts = img_result try: - import cv2 - decoded = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR) if decoded is not None: img = Image.from_numpy(decoded, frame_id="camera", ts=ts) @@ -555,22 +571,17 @@ def _handle_unity_message(self, topic: str, data: bytes) -> None: logger.warning(f"Image decode failed ({topic}): {e}") def _publish_camera_info(self, width: int, height: int, ts: float) -> None: - # NOTE: The Unity camera is a 360-degree cylindrical panorama (1920x640). - # CameraInfo assumes a pinhole model, so this is an approximation. - # The Rerun static pinhole (rerun_static_pinhole) uses a different focal - # length tuned for a 120-deg FOV window because Rerun has no cylindrical - # projection support. These intentionally differ. - fx = fy = height / 2.0 - cx, cy = width / 2.0, height / 2.0 + # Use the same intrinsics as rerun_static_pinhole (120° HFOV pinhole + # approximation of the cylindrical panorama). self.camera_info.publish( CameraInfo( height=height, width=width, distortion_model="plumb_bob", D=[0.0, 0.0, 0.0, 0.0, 0.0], - K=[fx, 0.0, cx, 0.0, fy, cy, 0.0, 0.0, 1.0], + K=[_CAM_FX, 0.0, _CAM_CX, 0.0, _CAM_FY, _CAM_CY, 0.0, 0.0, 1.0], R=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], - P=[fx, 0.0, cx, 0.0, 0.0, fy, cy, 0.0, 0.0, 0.0, 1.0, 0.0], + P=[_CAM_FX, 0.0, _CAM_CX, 0.0, 0.0, _CAM_FY, _CAM_CY, 0.0, 0.0, 0.0, 1.0, 0.0], frame_id="camera", ts=ts, ) @@ -585,29 +596,32 @@ def _send_to_unity(self, topic: str, data: bytes) -> None: def _sim_loop(self) -> None: dt = 1.0 / self.config.sim_rate - while self._running: + while self._running.is_set(): t0 = time.monotonic() with self._cmd_lock: fwd, left, yaw_rate = self._fwd_speed, self._left_speed, self._yaw_rate - prev_z = self._z + with self._state_lock: + prev_z = self._z - self._yaw += dt * yaw_rate - if self._yaw > PI: - self._yaw -= 2 * PI - elif self._yaw < -PI: - self._yaw += 2 * PI + self._yaw += dt * yaw_rate + if self._yaw > PI: + self._yaw -= 2 * PI + elif self._yaw < -PI: + self._yaw += 2 * PI - cy, sy = math.cos(self._yaw), math.sin(self._yaw) - self._x += dt * cy * fwd - dt * sy * left - self._y += dt * sy * fwd + dt * cy * left - with self._state_lock: - terrain_z = self._terrain_z - self._z = terrain_z + self.config.vehicle_height + cy, sy = math.cos(self._yaw), math.sin(self._yaw) + self._x += dt * cy * fwd - dt * sy * left + self._y += dt * sy * fwd + dt * cy * left + self._z = self._terrain_z + self.config.vehicle_height + + x, y, z = self._x, self._y, self._z + yaw = self._yaw + roll, pitch = self._roll, self._pitch now = time.time() - quat = Quaternion.from_euler(Vector3(self._roll, self._pitch, self._yaw)) + quat = Quaternion.from_euler(Vector3(roll, pitch, yaw)) self.odometry.publish( Odometry( @@ -615,11 +629,11 @@ def _sim_loop(self) -> None: frame_id="map", child_frame_id="sensor", pose=Pose( - position=[self._x, self._y, self._z], + position=[x, y, z], orientation=[quat.x, quat.y, quat.z, quat.w], ), twist=Twist( - linear=[fwd, left, (self._z - prev_z) * self.config.sim_rate], + linear=[fwd, left, (z - prev_z) * self.config.sim_rate], angular=[0.0, 0.0, yaw_rate], ), ) @@ -627,7 +641,7 @@ def _sim_loop(self) -> None: self.tf.publish( Transform( - translation=Vector3(self._x, self._y, self._z), + translation=Vector3(x, y, z), rotation=quat, frame_id="map", child_frame_id="sensor", @@ -647,15 +661,7 @@ def _sim_loop(self) -> None: if unity_connected: self._send_to_unity( "/unity_sim/set_model_state", - serialize_pose_stamped( - self._x, - self._y, - self._z, - quat.x, - quat.y, - quat.z, - quat.w, - ), + serialize_pose_stamped(x, y, z, quat.x, quat.y, quat.z, quat.w), ) sleep_for = dt - (time.monotonic() - t0) diff --git a/dimos/simulation/unity/test_unity_sim.py b/dimos/simulation/unity/test_unity_sim.py index aee84199b9..9eb57ef933 100644 --- a/dimos/simulation/unity/test_unity_sim.py +++ b/dimos/simulation/unity/test_unity_sim.py @@ -168,7 +168,7 @@ def test_module_survives_pickle(self): m = UnityBridgeModule(unity_binary="") m2 = pickle.loads(pickle.dumps(m)) assert hasattr(m2, "_cmd_lock") - assert m2._running is False + assert not m2._running.is_set() m.stop() m2.stop() @@ -186,6 +186,42 @@ def test_pointcloud2_round_trip(self): np.testing.assert_allclose(decoded_pts, pts, atol=1e-5) assert frame_id == "map" + def test_pointcloud2_empty(self): + pts = np.zeros((0, 3), dtype=np.float32) + data = _build_ros1_pointcloud2(pts) + result = deserialize_pointcloud2(data) + assert result is not None + decoded_pts, _, _ = result + assert len(decoded_pts) == 0 + + def test_pointcloud2_truncated(self): + pts = np.array([[1.0, 2.0, 3.0]], dtype=np.float32) + data = _build_ros1_pointcloud2(pts) + assert deserialize_pointcloud2(data[:10]) is None + + def test_pointcloud2_garbage(self): + assert deserialize_pointcloud2(b"\xff\x00\x01\x02") is None + + def test_compressed_image_truncated(self): + from dimos.utils.ros1 import deserialize_compressed_image + + assert deserialize_compressed_image(b"\x03\x00") is None + + def test_serialize_pose_stamped_round_trip(self): + from dimos.utils.ros1 import ROS1Reader, read_header, serialize_pose_stamped + + data = serialize_pose_stamped(1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0, frame_id="odom") + r = ROS1Reader(data) + header = read_header(r) + assert header.frame_id == "odom" + assert r.f64() == pytest.approx(1.0) + assert r.f64() == pytest.approx(2.0) + assert r.f64() == pytest.approx(3.0) + assert r.f64() == pytest.approx(0.0) # qx + assert r.f64() == pytest.approx(0.0) # qy + assert r.f64() == pytest.approx(0.0) # qz + assert r.f64() == pytest.approx(1.0) # qw + # TCP Bridge — needs sockets, ~1s, runs everywhere @@ -197,25 +233,26 @@ def test_handshake_and_data_flow(self): m = UnityBridgeModule(unity_binary="", unity_port=port) ts = _wire(m) - m._running = True + m._running.set() m._unity_thread = threading.Thread(target=m._unity_loop, daemon=True) m._unity_thread.start() time.sleep(0.3) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(("127.0.0.1", port)) - - dest, data = _recv_tcp(sock) - assert dest == "__handshake" + try: + sock.connect(("127.0.0.1", port)) - pts = np.array([[10.0, 20.0, 30.0]], dtype=np.float32) - _send_tcp(sock, "/registered_scan", _build_ros1_pointcloud2(pts)) - time.sleep(0.3) + dest, data = _recv_tcp(sock) + assert dest == "__handshake" - m._running = False - sock.close() - m._unity_thread.join(timeout=3) - m.stop() + pts = np.array([[10.0, 20.0, 30.0]], dtype=np.float32) + _send_tcp(sock, "/registered_scan", _build_ros1_pointcloud2(pts)) + time.sleep(0.3) + finally: + m._running.clear() + sock.close() + m._unity_thread.join(timeout=3) + m.stop() assert len(ts["registered_scan"]._messages) >= 1 received_pts, _ = ts["registered_scan"]._messages[0].as_numpy() @@ -230,13 +267,15 @@ def test_odometry_published(self): m = UnityBridgeModule(unity_binary="", sim_rate=100.0) ts = _wire(m) - m._running = True + m._running.set() m._sim_thread = threading.Thread(target=m._sim_loop, daemon=True) m._sim_thread.start() - time.sleep(0.2) - m._running = False - m._sim_thread.join(timeout=2) - m.stop() + try: + time.sleep(0.2) + finally: + m._running.clear() + m._sim_thread.join(timeout=2) + m.stop() assert len(ts["odometry"]._messages) > 5 assert ts["odometry"]._messages[0].frame_id == "map" @@ -246,13 +285,15 @@ def test_cmd_vel_moves_robot(self): ts = _wire(m) m._on_cmd_vel(Twist(linear=[1.0, 0.0, 0.0], angular=[0.0, 0.0, 0.0])) - m._running = True + m._running.set() m._sim_thread = threading.Thread(target=m._sim_loop, daemon=True) m._sim_thread.start() - time.sleep(1.0) - m._running = False - m._sim_thread.join(timeout=2) - m.stop() + try: + time.sleep(1.0) + finally: + m._running.clear() + m._sim_thread.join(timeout=2) + m.stop() last_odom = ts["odometry"]._messages[-1] assert last_odom.x > 0.5 diff --git a/dimos/utils/ros1.py b/dimos/utils/ros1.py index 9053ce20bc..3cac9ef67e 100644 --- a/dimos/utils/ros1.py +++ b/dimos/utils/ros1.py @@ -36,11 +36,14 @@ from __future__ import annotations from dataclasses import dataclass +import logging import struct import time import numpy as np +logger = logging.getLogger(__name__) + # Low-level readers @@ -255,6 +258,7 @@ def deserialize_pointcloud2(data: bytes) -> tuple[np.ndarray, str, float] | None return points, header.frame_id, header.stamp except Exception: + logger.debug("Failed to deserialize PointCloud2", exc_info=True) return None @@ -275,6 +279,7 @@ def deserialize_compressed_image(data: bytes) -> tuple[bytes, str, str, float] | img_data = r.raw(img_len) return img_data, fmt, header.frame_id, header.stamp except Exception: + logger.debug("Failed to deserialize CompressedImage", exc_info=True) return None From 002a419434d9161a73aa23e0caa9ed24995eeeb6 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 21 Mar 2026 22:07:48 -0700 Subject: [PATCH 34/42] fix: resolve merge conflicts + address Paul's review comments Merge conflict resolution: - Resolve typing_extensions conflicts in resource.py, blueprints.py, native_module.py, create.py, all_blueprints.py from dev merge Paul's review comments on ros1.py: - Use setup_logger() instead of logging.getLogger(__name__) - Use logger.exception() instead of logger.debug() in except blocks Paul's review comments on module.py: - Move platform validation constants into _validate_platform() - Add kwargs: Any type annotation to UnityBridgeModule.__init__ --- dimos/core/blueprints.py | 4 ---- dimos/core/native_module.py | 6 ------ dimos/core/resource.py | 4 ---- dimos/models/vl/create.py | 6 ------ dimos/robot/all_blueprints.py | 6 ------ dimos/simulation/unity/module.py | 11 ++++++----- dimos/utils/ros1.py | 9 +++++---- 7 files changed, 11 insertions(+), 35 deletions(-) diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index 134e561845..0b7403a11b 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -37,11 +37,7 @@ if sys.version_info >= (3, 11): from typing import Self else: -<<<<<<< HEAD - from typing import Any as Self -======= from typing_extensions import Self ->>>>>>> origin/dev logger = setup_logger() diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 446d5bc89d..74471f34d5 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -40,10 +40,7 @@ class MyCppModule(NativeModule): from __future__ import annotations -<<<<<<< HEAD import collections -======= ->>>>>>> origin/dev import enum import inspect import json @@ -139,10 +136,7 @@ class NativeModule(Module[_NativeConfig]): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) -<<<<<<< HEAD self._last_stderr_lines = collections.deque(maxlen=50) -======= ->>>>>>> origin/dev self._resolve_paths() @rpc diff --git a/dimos/core/resource.py b/dimos/core/resource.py index 032168456d..a4c008b806 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -15,9 +15,6 @@ from __future__ import annotations from abc import abstractmethod -<<<<<<< HEAD -from typing import TYPE_CHECKING, Self -======= import sys from typing import TYPE_CHECKING @@ -25,7 +22,6 @@ from typing import Self else: from typing_extensions import Self ->>>>>>> origin/dev if TYPE_CHECKING: from types import TracebackType diff --git a/dimos/models/vl/create.py b/dimos/models/vl/create.py index a6ee1070bd..b39159c54f 100644 --- a/dimos/models/vl/create.py +++ b/dimos/models/vl/create.py @@ -1,15 +1,9 @@ from typing import Any -from dimos.models.vl.types import VlModelName from dimos.models.vl.base import VlModel -<<<<<<< HEAD from dimos.models.vl.types import VlModelName -======= - - ->>>>>>> origin/dev def create(name: VlModelName) -> VlModel[Any]: # This uses inline imports to only import what's needed. match name: diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index fb0bc4a878..bcf96bf184 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -89,10 +89,7 @@ "unitree-go2-temporal-memory": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_temporal_memory:unitree_go2_temporal_memory", "unitree-go2-vlm-stream-test": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_vlm_stream_test:unitree_go2_vlm_stream_test", "unitree-go2-webrtc-keyboard-teleop": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_webrtc_keyboard_teleop:unitree_go2_webrtc_keyboard_teleop", -<<<<<<< HEAD "unity-sim": "dimos.simulation.unity.blueprint:unity_sim", -======= ->>>>>>> origin/dev "xarm-perception": "dimos.manipulation.blueprints:xarm_perception", "xarm-perception-agent": "dimos.manipulation.blueprints:xarm_perception_agent", "xarm6-planner-only": "dimos.manipulation.blueprints:xarm6_planner_only", @@ -170,10 +167,7 @@ "twist-teleop-module": "dimos.teleop.quest.quest_extensions", "unitree-g1-skill-container": "dimos.robot.unitree.g1.skill_container", "unitree-skill-container": "dimos.robot.unitree.unitree_skill_container", -<<<<<<< HEAD "unity-bridge-module": "dimos.simulation.unity.module", -======= ->>>>>>> origin/dev "vlm-agent": "dimos.agents.vlm_agent", "vlm-stream-tester": "dimos.agents.vlm_stream_tester", "voxel-grid-mapper": "dimos.mapping.voxels", diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index 1f02395ba4..537db48426 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -73,8 +73,6 @@ # LFS data asset name for the Unity sim binary _LFS_ASSET = "unity_sim_x86" -_SUPPORTED_SYSTEMS = {"Linux"} -_SUPPORTED_ARCHS = {"x86_64", "AMD64"} # Read timeout for the Unity TCP connection (seconds). If Unity stops # sending data for longer than this the bridge treats it as a hung @@ -128,17 +126,20 @@ def _write_tcp_command(sock: socket.socket, command: str, params: dict[str, Any] def _validate_platform() -> None: """Raise if the current platform can't run the Unity x86_64 binary.""" + supported_systems = {"Linux"} + supported_archs = {"x86_64", "AMD64"} + system = platform.system() arch = platform.machine() - if system not in _SUPPORTED_SYSTEMS: + if system not in supported_systems: raise RuntimeError( f"Unity simulator requires Linux x86_64 but running on {system} {arch}. " f"macOS and Windows are not supported (the binary is a Linux ELF executable). " f"Use a Linux VM, Docker, or WSL2." ) - if arch not in _SUPPORTED_ARCHS: + if arch not in supported_archs: raise RuntimeError( f"Unity simulator requires x86_64 but running on {arch}. " f"ARM64 Linux is not supported. Use an x86_64 machine or emulation layer." @@ -253,7 +254,7 @@ def rerun_suppress_camera_info(_: Any) -> None: """Suppress CameraInfo logging — the static pinhole handles 3D projection.""" return None - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self._x = self.config.init_x self._y = self.config.init_y diff --git a/dimos/utils/ros1.py b/dimos/utils/ros1.py index 3cac9ef67e..dd1f9b9224 100644 --- a/dimos/utils/ros1.py +++ b/dimos/utils/ros1.py @@ -36,13 +36,14 @@ from __future__ import annotations from dataclasses import dataclass -import logging import struct import time import numpy as np -logger = logging.getLogger(__name__) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() # Low-level readers @@ -258,7 +259,7 @@ def deserialize_pointcloud2(data: bytes) -> tuple[np.ndarray, str, float] | None return points, header.frame_id, header.stamp except Exception: - logger.debug("Failed to deserialize PointCloud2", exc_info=True) + logger.exception("Failed to deserialize PointCloud2") return None @@ -279,7 +280,7 @@ def deserialize_compressed_image(data: bytes) -> tuple[bytes, str, str, float] | img_data = r.raw(img_len) return img_data, fmt, header.frame_id, header.stamp except Exception: - logger.debug("Failed to deserialize CompressedImage", exc_info=True) + logger.exception("Failed to deserialize CompressedImage") return None From 182cf28e49d8da203db1d19079d25c3f0a0f28e5 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sun, 22 Mar 2026 14:31:26 -0700 Subject: [PATCH 35/42] refactor: remove unnecessary __getstate__/__setstate__ from UnityBridgeModule Modules are instantiated directly in worker subprocesses and never pickled across the process boundary. The pickle state methods were only exercised by a dedicated test, not by actual runtime. --- dimos/simulation/unity/module.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index 537db48426..98aca868df 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -277,33 +277,6 @@ def __init__(self, **kwargs: Any) -> None: self._send_queue: Queue[tuple[str, bytes]] = Queue() self._binary_path = self._resolve_binary() - def __getstate__(self) -> dict[str, Any]: # type: ignore[override] - state: dict[str, Any] = super().__getstate__() # type: ignore[no-untyped-call] - for key in ( - "_cmd_lock", - "_state_lock", - "_sim_thread", - "_unity_thread", - "_unity_process", - "_send_queue", - "_unity_ready", - "_running", - ): - state.pop(key, None) - return state - - def __setstate__(self, state: dict[str, Any]) -> None: - super().__setstate__(state) - self._cmd_lock = threading.Lock() - self._state_lock = threading.Lock() - self._sim_thread = None - self._unity_thread = None - self._unity_process = None - self._send_queue = Queue() - self._unity_ready = threading.Event() - self._running = threading.Event() - self._binary_path = self._resolve_binary() - @rpc def start(self) -> None: super().start() From 9b609bde35c3fbf4b12474f51d300efbafcd5f2c Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 21 Mar 2026 23:11:11 -0700 Subject: [PATCH 36/42] fix(unity): launch Unity in thread to avoid blocking start() _launch_unity() blocks up to 30s waiting for Unity to connect. This stalls the entire blueprint build. Move to a daemon thread. Revert: git revert HEAD --- dimos/simulation/unity/module.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index 98aca868df..d298223d08 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -287,7 +287,9 @@ def start(self) -> None: self._sim_thread.start() self._unity_thread = threading.Thread(target=self._unity_loop, daemon=True) self._unity_thread.start() - self._launch_unity() + # Launch Unity in a thread to avoid blocking start() for up to + # unity_connect_timeout seconds (default 30s). + threading.Thread(target=self._launch_unity, daemon=True).start() @rpc def stop(self) -> None: From 47d99da785e8940a534e49c38911e33ebe8fce1c Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 21 Mar 2026 23:11:28 -0700 Subject: [PATCH 37/42] fix(unity): pipe Unity stderr to logger instead of discarding Unity stdout/stderr were both sent to DEVNULL, making crashes undiagnosable. Now stderr is piped to a reader thread that logs each line at warning level. Revert: git revert HEAD --- dimos/simulation/unity/module.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index d298223d08..51a434df11 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -371,8 +371,20 @@ def _launch_unity(self) -> None: cwd=str(binary_path.parent), env=env, stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, + stderr=subprocess.PIPE, ) + + # Read Unity stderr in a background thread for diagnostics. + def _drain_stderr() -> None: + assert self._unity_process is not None + assert self._unity_process.stderr is not None + for raw in self._unity_process.stderr: + line = raw.decode("utf-8", errors="replace").rstrip() + if line: + logger.warning(f"Unity stderr: {line}") + self._unity_process.stderr.close() + + threading.Thread(target=_drain_stderr, daemon=True).start() logger.info(f"Unity pid={self._unity_process.pid}, waiting for TCP connection...") if self._unity_ready.wait(timeout=self.config.unity_connect_timeout): From 9f0f7b932ba40c2dcf86e6cd805f99eb30c2cb72 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sat, 21 Mar 2026 23:11:55 -0700 Subject: [PATCH 38/42] fix(unity): clear _unity_ready on disconnect _unity_ready was a one-shot Event that was never cleared. On reconnect the state of _unity_connected and _unity_ready was inconsistent. Now cleared when a connection ends. Revert: git revert HEAD --- dimos/simulation/unity/module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index 51a434df11..3c167d05c3 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -442,6 +442,7 @@ def _unity_loop(self) -> None: finally: with self._state_lock: self._unity_connected = False + self._unity_ready.clear() conn.close() except TimeoutError: continue From 258b0ccdb5e11a0bdf91d0a27fa35ec25d9a1859 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sun, 22 Mar 2026 14:36:20 -0700 Subject: [PATCH 39/42] test: remove pickle test (follows __getstate__/__setstate__ removal) --- dimos/simulation/unity/test_unity_sim.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/dimos/simulation/unity/test_unity_sim.py b/dimos/simulation/unity/test_unity_sim.py index 9eb57ef933..7ac9c49296 100644 --- a/dimos/simulation/unity/test_unity_sim.py +++ b/dimos/simulation/unity/test_unity_sim.py @@ -24,7 +24,6 @@ """ import os -import pickle import platform import socket import struct @@ -160,19 +159,6 @@ def test_rejects_unsupported_platform(self): _validate_platform() -# Pickle — fast, runs everywhere - - -class TestPickle: - def test_module_survives_pickle(self): - m = UnityBridgeModule(unity_binary="") - m2 = pickle.loads(pickle.dumps(m)) - assert hasattr(m2, "_cmd_lock") - assert not m2._running.is_set() - m.stop() - m2.stop() - - # ROS1 Deserialization — fast, runs everywhere From 66e1819078cc50434ef5193be49475da717cfb5e Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sun, 22 Mar 2026 15:15:33 -0700 Subject: [PATCH 40/42] fix(unity): thread safety for _unity_process and stderr drain - Guard _drain_stderr against process being killed by stop() - Protect _unity_process reads/writes with _state_lock - Remove redundant _unity_connected=False in _bridge_connection --- dimos/simulation/unity/module.py | 44 +++++++++++++++++--------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index 3c167d05c3..e1354107e5 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -298,17 +298,17 @@ def stop(self) -> None: self._sim_thread.join(timeout=2.0) if self._unity_thread: self._unity_thread.join(timeout=2.0) - if self._unity_process is not None and self._unity_process.poll() is None: - logger.info(f"Stopping Unity (pid={self._unity_process.pid})") - self._unity_process.send_signal(signal.SIGTERM) + with self._state_lock: + proc = self._unity_process + self._unity_process = None + if proc is not None and proc.poll() is None: + logger.info(f"Stopping Unity (pid={proc.pid})") + proc.send_signal(signal.SIGTERM) try: - self._unity_process.wait(timeout=5) + proc.wait(timeout=5) except subprocess.TimeoutExpired: - logger.warning( - f"Unity pid={self._unity_process.pid} did not exit after SIGTERM, killing" - ) - self._unity_process.kill() - self._unity_process = None + logger.warning(f"Unity pid={proc.pid} did not exit after SIGTERM, killing") + proc.kill() super().stop() def _resolve_binary(self) -> Path | None: @@ -366,23 +366,29 @@ def _launch_unity(self) -> None: if "DISPLAY" not in env and not self.config.headless: env["DISPLAY"] = ":0" - self._unity_process = subprocess.Popen( + proc = subprocess.Popen( cmd, cwd=str(binary_path.parent), env=env, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, ) + with self._state_lock: + self._unity_process = proc # Read Unity stderr in a background thread for diagnostics. + proc = self._unity_process # capture ref — stop() may clear self._unity_process + def _drain_stderr() -> None: - assert self._unity_process is not None - assert self._unity_process.stderr is not None - for raw in self._unity_process.stderr: - line = raw.decode("utf-8", errors="replace").rstrip() - if line: - logger.warning(f"Unity stderr: {line}") - self._unity_process.stderr.close() + try: + assert proc.stderr is not None + for raw in proc.stderr: + line = raw.decode("utf-8", errors="replace").rstrip() + if line: + logger.warning(f"Unity stderr: {line}") + proc.stderr.close() + except (OSError, ValueError): + pass # process killed or pipe closed by stop() threading.Thread(target=_drain_stderr, daemon=True).start() logger.info(f"Unity pid={self._unity_process.pid}, waiting for TCP connection...") @@ -391,7 +397,7 @@ def _drain_stderr() -> None: logger.info("Unity connected") else: # Check if process died - rc = self._unity_process.poll() + rc = proc.poll() if rc is not None: logger.error( f"Unity process exited with code {rc} before connecting. " @@ -494,8 +500,6 @@ def _bridge_connection(self, sock: socket.socket) -> None: finally: halt.set() sender.join(timeout=2.0) - with self._state_lock: - self._unity_connected = False def _unity_sender(self, sock: socket.socket, halt: threading.Event) -> None: while not halt.is_set(): From d6bf9fb8892a7321c7eace8c5a2606ce32c3a150 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sun, 22 Mar 2026 15:16:30 -0700 Subject: [PATCH 41/42] fix(lfs): repack unity_sim_x86 tarball with correct directory name The tarball contained cmu_unity_sim_x86/ as the top-level directory but get_data() expects unity_sim_x86/ (matching the asset name). Repacked so extraction works without symlinks. --- data/.lfs/unity_sim_x86.tar.gz | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data/.lfs/unity_sim_x86.tar.gz b/data/.lfs/unity_sim_x86.tar.gz index 00212578a9..d070563158 100644 --- a/data/.lfs/unity_sim_x86.tar.gz +++ b/data/.lfs/unity_sim_x86.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b02bb692abceedb05e5d85efc0f9c1b1f0d605b4ae011c1a98d35c64036abc11 -size 133299059 +oid sha256:d61381e42c63919e6d4bd3ef9f36f1b3b1a60cc61ad214fa308625cd67dcd100 +size 133295482 From be5666cb0a31ba5e3cbc179354922429d1ef0a84 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Sun, 22 Mar 2026 15:50:54 -0700 Subject: [PATCH 42/42] fix default environment --- data/.lfs/unity_sim_x86.tar.gz | 4 ++-- dimos/robot/unitree/go2/connection.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/data/.lfs/unity_sim_x86.tar.gz b/data/.lfs/unity_sim_x86.tar.gz index d070563158..15c06301fc 100644 --- a/data/.lfs/unity_sim_x86.tar.gz +++ b/data/.lfs/unity_sim_x86.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d61381e42c63919e6d4bd3ef9f36f1b3b1a60cc61ad214fa308625cd67dcd100 -size 133295482 +oid sha256:d4ce5b93751657cc991c4242c227627ec3bbc0263085312e602eae264652d3ac +size 581676645 diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index db3ecb40fc..5123dc9a31 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -229,7 +229,8 @@ def record(self, recording_name: str) -> None: @rpc def start(self) -> None: super().start() - + if not hasattr(self, "connection"): + return self.connection.start() def onimage(image: Image) -> None: