From f82d270e6618f20f0d86349d6c6536fc34c6a483 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Mon, 11 May 2026 08:20:29 +0300 Subject: [PATCH] feat: use individual module connections --- dimos/agents/skills/person_follow.py | 4 +- dimos/core/coordination/blueprints.py | 95 ++++ dimos/core/coordination/test_blueprints.py | 140 +++++- dimos/core/global_config.py | 12 + dimos/core/module.py | 7 + dimos/core/test_global_config.py | 27 +- dimos/e2e_tests/dimos_cli_call.py | 2 +- .../security_demo/security_module.py | 4 +- dimos/memory2/test_visualizer.py | 4 +- dimos/perception/detection/conftest.py | 11 +- dimos/project/test_no_init_files.py | 11 +- dimos/robot/all_blueprints.py | 10 +- dimos/robot/cli/dimos.py | 9 + dimos/robot/connection_registry.py | 64 +++ dimos/robot/drone/connection_module.py | 31 +- dimos/robot/tf_utils.py | 64 +++ dimos/robot/unitree/connection.py | 435 ------------------ dimos/robot/unitree/g1/__init__.py | 21 + .../g1/blueprints/basic/unitree_g1_basic.py | 4 +- .../blueprints/basic/unitree_g1_basic_sim.py | 4 +- dimos/robot/unitree/g1/connection.py | 95 ++-- .../unitree/g1/effectors/high_level/webrtc.py | 36 +- dimos/robot/unitree/g1/mujoco_sim.py | 143 +----- dimos/robot/unitree/go2/__init__.py | 23 + .../go2/blueprints/agentic/_common_agentic.py | 4 +- .../go2/blueprints/basic/unitree_go2_basic.py | 4 +- .../basic/unitree_go2_coordinator.py | 10 +- ...unitree_go2_webrtc_rage_keyboard_teleop.py | 4 +- .../blueprints/smart/unitree_go2_detection.py | 4 +- .../blueprints/smart/unitree_go2_spatial.py | 4 +- dimos/robot/unitree/go2/camera.py | 36 ++ dimos/robot/unitree/go2/config.py | 38 ++ dimos/robot/unitree/go2/connection.py | 386 ---------------- dimos/robot/unitree/go2/connection_mujoco.py | 97 ++++ dimos/robot/unitree/go2/connection_replay.py | 166 +++++++ dimos/robot/unitree/go2/connection_webrtc.py | 322 +++++++++++++ dimos/robot/unitree/go2/fleet_connection.py | 127 +++-- .../robot/unitree/mujoco_camera_constants.py | 51 ++ dimos/robot/unitree/mujoco_connection.py | 263 ++++++----- dimos/robot/unitree/type/map.py | 12 - dimos/robot/unitree/webrtc_session.py | 186 ++++++++ dimos/utils/testing/test_moment.py | 7 +- 42 files changed, 1719 insertions(+), 1258 deletions(-) create mode 100644 dimos/robot/connection_registry.py create mode 100644 dimos/robot/tf_utils.py delete mode 100644 dimos/robot/unitree/connection.py create mode 100644 dimos/robot/unitree/g1/__init__.py create mode 100644 dimos/robot/unitree/go2/__init__.py create mode 100644 dimos/robot/unitree/go2/camera.py create mode 100644 dimos/robot/unitree/go2/config.py delete mode 100644 dimos/robot/unitree/go2/connection.py create mode 100644 dimos/robot/unitree/go2/connection_mujoco.py create mode 100644 dimos/robot/unitree/go2/connection_replay.py create mode 100644 dimos/robot/unitree/go2/connection_webrtc.py create mode 100644 dimos/robot/unitree/mujoco_camera_constants.py create mode 100644 dimos/robot/unitree/webrtc_session.py diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index d8fc0beb18..b4ae8abe90 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -81,9 +81,9 @@ def __init__(self, **kwargs: Any) -> None: # Use MuJoCo camera intrinsics in simulation mode camera_info = self.config.camera_info if self.config.g.simulation: - from dimos.robot.unitree.mujoco_connection import MujocoConnection + from dimos.robot.unitree.mujoco_camera_constants import MUJOCO_CAMERA_INFO_STATIC - camera_info = MujocoConnection.camera_info_static + camera_info = MUJOCO_CAMERA_INFO_STATIC self._visual_servo = VisualServoing2D(camera_info, self.config.g.simulation) self._detection_navigation = DetectionNavigation(self.tf, camera_info) diff --git a/dimos/core/coordination/blueprints.py b/dimos/core/coordination/blueprints.py index 7d94555de4..f02d87125d 100644 --- a/dimos/core/coordination/blueprints.py +++ b/dimos/core/coordination/blueprints.py @@ -195,6 +195,66 @@ def requirements(self, *checks: Callable[[], str | None]) -> "Blueprint": def configurators(self, *checks: "SystemConfigurator") -> "Blueprint": return replace(self, configurator_checks=self.configurator_checks + tuple(checks)) + def with_backend(self, backend: str) -> "Blueprint": + """Swap tagged connection modules for the matching `(robot, backend)` variant. + + For each atom whose module carries a `_connection_tag` with a backend + different from the requested one, look up the same-robot module for + `backend` in the connection registry and substitute it. Streams, + remappings, and disabled-modules entries are rewritten to point at the + new class. + + If the blueprint has no tagged atoms this is a no-op (with a warning). + """ + # Lazy import to keep blueprints.py free of robot deps. + from dimos.robot.connection_registry import backends_for, get_connection + + swap_map: dict[type[ModuleBase], type[ModuleBase]] = {} + for atom in self.blueprints: + tag = getattr(atom.module, "_connection_tag", None) + if tag is None or tag.backend == backend: + continue + target = get_connection(tag.robot, backend) + if target is None: + available = sorted(backends_for(tag.robot)) + raise ValueError( + f"No connection registered for robot={tag.robot!r} " + f"backend={backend!r} (have: {available})" + ) + swap_map[atom.module] = target + + if not swap_map: + tagged = any(getattr(a.module, "_connection_tag", None) for a in self.blueprints) + if not tagged: + logger.warning( + "Blueprint.with_backend(%r) had no tagged connection atoms — " + "returning blueprint unchanged", + backend, + ) + return self + + new_atoms: list[BlueprintAtom] = [] + for atom in self.blueprints: + target = swap_map.get(atom.module) + if target is None: + new_atoms.append(atom) + continue + _check_stream_parity(atom.module, target, atom) + _check_kwargs_compat(target, atom.kwargs) + new_atoms.append(BlueprintAtom.create(target, atom.kwargs)) + + new_remappings = { + (swap_map.get(m, m), name): v for (m, name), v in self.remapping_map.items() + } + new_disabled = tuple(swap_map.get(m, m) for m in self.disabled_modules_tuple) + + return replace( + self, + blueprints=tuple(new_atoms), + remapping_map=MappingProxyType(new_remappings), + disabled_modules_tuple=new_disabled, + ) + @cached_property def active_blueprints(self) -> tuple[BlueprintAtom, ...]: if not self.disabled_modules_tuple: @@ -239,3 +299,38 @@ def _eliminate_duplicates(blueprints: list[BlueprintAtom]) -> list[BlueprintAtom seen.add(bp.module) unique_blueprints.append(bp) return list(reversed(unique_blueprints)) + + +def _stream_signature(streams: tuple[StreamRef, ...]) -> set[tuple[str, str]]: + return {(s.name, s.direction) for s in streams} + + +def _check_stream_parity(old: type[ModuleBase], new: type[ModuleBase], atom: BlueprintAtom) -> None: + new_atom = BlueprintAtom.create(new, atom.kwargs) + old_sig = _stream_signature(atom.streams) + new_sig = _stream_signature(new_atom.streams) + if old_sig != new_sig: + only_old = sorted(old_sig - new_sig) + only_new = sorted(new_sig - old_sig) + raise ValueError( + f"Stream surface drift swapping {old.__name__} -> {new.__name__}: " + f"only on {old.__name__}={only_old}, only on {new.__name__}={only_new}" + ) + + +def _check_kwargs_compat(new: type[ModuleBase], kwargs: dict[str, Any]) -> None: + if not kwargs: + return + try: + config_type = get_type_hints(new).get("config") + except Exception: + return + if config_type is None: + return + valid_fields = set(getattr(config_type, "model_fields", {})) + invalid = set(kwargs) - valid_fields + if invalid: + raise ValueError( + f"Kwargs from blueprint atom are incompatible with {new.__name__}'s " + f"config ({config_type.__name__}): unknown field(s) {sorted(invalid)}" + ) diff --git a/dimos/core/coordination/test_blueprints.py b/dimos/core/coordination/test_blueprints.py index f91d047d91..081df14e98 100644 --- a/dimos/core/coordination/test_blueprints.py +++ b/dimos/core/coordination/test_blueprints.py @@ -33,9 +33,11 @@ autoconnect, ) from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport +from dimos.robot import connection_registry +from dimos.robot.connection_registry import connection from dimos.spec.utils import Spec @@ -232,3 +234,139 @@ def test_active_blueprints_filters_disabled() -> None: active_modules = {bp.module for bp in blueprint.active_blueprints} assert ModuleA not in active_modules assert ModuleB in active_modules + + +@pytest.fixture +def isolated_registry(monkeypatch): + monkeypatch.setattr(connection_registry, "_REGISTRY", {}) + yield connection_registry._REGISTRY + + +class _BotConfig(ModuleConfig): + setting: str = "default" + + +def _bot_modules(): + """Create three (robot=bot) connection variants in an isolated registry.""" + + @connection(robot="bot", backend="real") + class BotReal(Module): + config: _BotConfig + cmd: In[Data1] + odom: Out[Data2] + + @connection(robot="bot", backend="sim") + class BotSim(Module): + config: _BotConfig + cmd: In[Data1] + odom: Out[Data2] + + @connection(robot="bot", backend="replay") + class BotReplay(Module): + config: ModuleConfig + cmd: In[Data1] + odom: Out[Data2] + + return BotReal, BotSim, BotReplay + + +def test_with_backend_no_op_when_no_tagged_atoms(isolated_registry) -> None: + blueprint = autoconnect(ModuleA.blueprint(), ModuleB.blueprint()) + swapped = blueprint.with_backend("sim") + assert swapped is blueprint + + +def test_with_backend_swaps_tagged_atom(isolated_registry) -> None: + BotReal, BotSim, _ = _bot_modules() + + blueprint = autoconnect(ModuleA.blueprint(), BotReal.blueprint(setting="x")) + swapped = blueprint.with_backend("sim") + + swapped_modules = [a.module for a in swapped.blueprints] + assert BotSim in swapped_modules + assert BotReal not in swapped_modules + assert ModuleA in swapped_modules + + bot_atom = next(a for a in swapped.blueprints if a.module is BotSim) + assert bot_atom.kwargs == {"setting": "x"} + # Streams were re-extracted from the new class. + assert {s.name for s in bot_atom.streams} == {"cmd", "odom"} + + +def test_with_backend_no_op_when_already_target(isolated_registry) -> None: + BotReal, _, _ = _bot_modules() + + blueprint = BotReal.blueprint() + swapped = blueprint.with_backend("real") + assert swapped is blueprint # no atoms needed swapping; returns self + + +def test_with_backend_unknown_backend_raises(isolated_registry) -> None: + BotReal, _, _ = _bot_modules() + + blueprint = BotReal.blueprint() + with pytest.raises(ValueError, match="No connection registered.*backend='nope'"): + blueprint.with_backend("nope") + + +def test_with_backend_rewrites_remappings(isolated_registry) -> None: + BotReal, BotSim, _ = _bot_modules() + + blueprint = BotReal.blueprint().remappings([(BotReal, "cmd", "remapped_cmd")]) + swapped = blueprint.with_backend("sim") + + assert (BotReal, "cmd") not in swapped.remapping_map + assert swapped.remapping_map[(BotSim, "cmd")] == "remapped_cmd" + + +def test_with_backend_rewrites_disabled_modules(isolated_registry) -> None: + BotReal, BotSim, _ = _bot_modules() + + blueprint = autoconnect(BotReal.blueprint(), ModuleA.blueprint()).disabled_modules(BotReal) + swapped = blueprint.with_backend("sim") + + assert BotReal not in swapped.disabled_modules_tuple + assert BotSim in swapped.disabled_modules_tuple + + +def test_with_backend_eager_kwarg_validation_raises(isolated_registry) -> None: + @connection(robot="bot", backend="real") + class BotReal2(Module): + class Cfg(ModuleConfig): + mode: str = "default" + speed: int = 1 + + config: Cfg + cmd: In[Data1] + + @connection(robot="bot", backend="sim") + class BotSim2(Module): + class Cfg(ModuleConfig): + speed: int = 1 # NOTE: no `mode` field + + config: Cfg + cmd: In[Data1] + + blueprint = BotReal2.blueprint(mode="rage") + with pytest.raises(ValueError, match="unknown field.*mode"): + blueprint.with_backend("sim") + + +def test_with_backend_stream_parity_drift_raises(isolated_registry) -> None: + @connection(robot="bot", backend="real") + class BotReal3(Module): + config: ModuleConfig + cmd: In[Data1] + odom: Out[Data2] + + @connection(robot="bot", backend="sim") + class BotSim3(Module): + config: ModuleConfig + cmd: In[Data1] + # missing odom; adds extra stream + + extra: Out[Data3] + + blueprint = BotReal3.blueprint() + with pytest.raises(ValueError, match="Stream surface drift"): + blueprint.with_backend("sim") diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index 77bd677227..afb8a19e95 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -37,6 +37,7 @@ class GlobalConfig(BaseSettings): xarm6_ip: str | None = None can_port: str | None = None simulation: bool = False + simulator: str | None = None replay: bool = False replay_db: str = "go2_short" new_memory: bool = False @@ -83,10 +84,21 @@ def update(self, **kwargs: object) -> None: def unitree_connection_type(self) -> str: if self.replay: return "replay" + if self.simulator: + return self.simulator if self.simulation: return "mujoco" return "webrtc" + @property + def effective_simulator(self) -> str | None: + """Resolved simulator backend from --simulator or --simulation.""" + if self.simulator: + return self.simulator + if self.simulation: + return "mujoco" + return None + @property def mujoco_start_pos_float(self) -> tuple[float, float]: x, y = _get_all_numbers(self.mujoco_start_pos) diff --git a/dimos/core/module.py b/dimos/core/module.py index 259118098f..72d7747050 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -105,6 +105,12 @@ class _BlueprintPartial(Protocol): def __call__(self, **kwargs: Any) -> "Blueprint": ... +@dataclass(frozen=True) +class ConnectionTag: + robot: str + backend: str + + class ModuleBase(Configurable, CompositeResource): config: ModuleConfig @@ -123,6 +129,7 @@ class ModuleBase(Configurable, CompositeResource): _main_gen: AsyncGenerator[None, None] | None = None _tools: dict[str, Any] _tools_lock: threading.Lock + _connection_tag: ConnectionTag | None = None def __init__(self, config_args: dict[str, Any]) -> None: super().__init__(**config_args) diff --git a/dimos/core/test_global_config.py b/dimos/core/test_global_config.py index d42d004bd2..62868d0a6c 100644 --- a/dimos/core/test_global_config.py +++ b/dimos/core/test_global_config.py @@ -12,16 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for GlobalConfig security defaults.""" +from dimos.core.global_config import GlobalConfig class TestGlobalConfigSecurityDefaults: """Network services must bind to localhost by default (not 0.0.0.0).""" def test_listen_host_defaults_to_localhost(self) -> None: - from dimos.core.global_config import GlobalConfig - config = GlobalConfig() assert config.listen_host == "127.0.0.1", ( f"listen_host must default to 127.0.0.1, got {config.listen_host}" ) + + +class TestSimulatorBackendResolution: + """`--simulator` and `--simulation` translate into the connection backend.""" + + def test_simulator_takes_precedence_over_simulation(self) -> None: + config = GlobalConfig(simulation=True, simulator="simsim") + assert config.effective_simulator == "simsim" + assert config.unitree_connection_type == "simsim" + + def test_simulation_back_compat_resolves_to_mujoco(self) -> None: + config = GlobalConfig(simulation=True) + assert config.effective_simulator == "mujoco" + assert config.unitree_connection_type == "mujoco" + + def test_neither_set_returns_none_and_webrtc(self) -> None: + config = GlobalConfig(simulation=False, simulator=None) + assert config.effective_simulator is None + assert config.unitree_connection_type == "webrtc" + + def test_replay_overrides_simulator(self) -> None: + config = GlobalConfig(replay=True, simulator="mujoco") + assert config.unitree_connection_type == "replay" diff --git a/dimos/e2e_tests/dimos_cli_call.py b/dimos/e2e_tests/dimos_cli_call.py index 569dcfa386..6130175a61 100644 --- a/dimos/e2e_tests/dimos_cli_call.py +++ b/dimos/e2e_tests/dimos_cli_call.py @@ -34,7 +34,7 @@ def start(self) -> None: args = ["run", *args] self.process = subprocess.Popen( - ["dimos", "--simulation", *args], + ["dimos", "--simulator=mujoco", *args], start_new_session=True, ) diff --git a/dimos/experimental/security_demo/security_module.py b/dimos/experimental/security_demo/security_module.py index 9dffc5714c..e62665b5c0 100644 --- a/dimos/experimental/security_demo/security_module.py +++ b/dimos/experimental/security_demo/security_module.py @@ -119,9 +119,9 @@ def _create_visual_servo( ) -> VisualServoing2D: camera_info = config.camera_info if global_config.simulation: - from dimos.robot.unitree.mujoco_connection import MujocoConnection + from dimos.robot.unitree.mujoco_camera_constants import MUJOCO_CAMERA_INFO_STATIC - camera_info = MujocoConnection.camera_info_static + camera_info = MUJOCO_CAMERA_INFO_STATIC return VisualServoing2D(camera_info, global_config.simulation) diff --git a/dimos/memory2/test_visualizer.py b/dimos/memory2/test_visualizer.py index 0830c946fd..495cd38aa0 100644 --- a/dimos/memory2/test_visualizer.py +++ b/dimos/memory2/test_visualizer.py @@ -29,7 +29,7 @@ from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC -from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.robot.unitree.go2.camera import GO2_CAMERA_INFO_STATIC from dimos.utils.data import get_data, get_data_dir if TYPE_CHECKING: @@ -192,7 +192,7 @@ def test_detect_objects_smart(self, store: SqliteStore, clip: CLIPModel) -> None det3d = Detection3DPC.from_2d( det, lidar_frame, - camera_info=GO2Connection.camera_info_static, + camera_info=GO2_CAMERA_INFO_STATIC, world_to_optical_transform=Transform( ts=obs.ts, translation=obs.pose_stamped.position, diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index afd11cabf6..afc8e94bde 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -37,7 +37,8 @@ from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC from dimos.protocol.tf.tf import TF -from dimos.robot.unitree.go2 import connection +from dimos.robot.tf_utils import odom_to_tf +from dimos.robot.unitree.go2.camera import _camera_info_static from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data @@ -101,7 +102,7 @@ def moment_provider(**kwargs) -> Moment: if odom_frame is None: raise ValueError("No odom frame found") - transforms = connection.GO2Connection._odom_to_tf(odom_frame) + transforms = odom_to_tf(odom_frame) tf.receive_transform(*transforms) @@ -109,7 +110,7 @@ def moment_provider(**kwargs) -> Moment: "odom_frame": odom_frame, "lidar_frame": lidar_frame, "image_frame": image_frame, - "camera_info": connection._camera_info_static(), + "camera_info": _camera_info_static(), "transforms": transforms, "tf": tf, } @@ -266,8 +267,8 @@ def object_db_module(get_moment): c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) - module3d = Detection3DModule(camera_info=connection._camera_info_static()) - moduleDB = ObjectDBModule(camera_info=connection._camera_info_static()) + module3d = Detection3DModule(camera_info=_camera_info_static()) + moduleDB = ObjectDBModule(camera_info=_camera_info_static()) # Process 5 frames to build up object history for i in range(5): diff --git a/dimos/project/test_no_init_files.py b/dimos/project/test_no_init_files.py index 22ebb729c3..605f43f08c 100644 --- a/dimos/project/test_no_init_files.py +++ b/dimos/project/test_no_init_files.py @@ -17,9 +17,16 @@ def test_no_init_files(): dimos_dir = DIMOS_PROJECT_ROOT / "dimos" + + allowed = { + dimos_dir / "__init__.py", + dimos_dir / "robot/unitree/g1/__init__.py", + dimos_dir / "robot/unitree/go2/__init__.py", + } + init_files = sorted(dimos_dir.rglob("__init__.py")) - # The root dimos/__init__.py is allowed for the porcelain lazy import. - init_files = [f for f in init_files if f != dimos_dir / "__init__.py"] + init_files = [f for f in init_files if f not in allowed] + if init_files: listing = "\n".join(f" - {f.relative_to(dimos_dir)}" for f in init_files) raise AssertionError( diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index b88159dd8e..6d5bc4ff7a 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -134,15 +134,16 @@ "far-planner": "dimos.navigation.nav_stack.modules.far_planner.far_planner.FarPlanner", "fast-lio2": "dimos.hardware.sensors.lidar.fastlio2.module.FastLio2", "foxglove-bridge": "dimos.robot.foxglove_bridge.FoxgloveBridge", - "g1-connection": "dimos.robot.unitree.g1.connection.G1Connection", - "g1-connection-base": "dimos.robot.unitree.g1.connection.G1ConnectionBase", "g1-high-level-dds-sdk": "dimos.robot.unitree.g1.effectors.high_level.dds_sdk.G1HighLevelDdsSdk", "g1-high-level-web-rtc": "dimos.robot.unitree.g1.effectors.high_level.webrtc.G1HighLevelWebRtc", - "g1-sim-connection": "dimos.robot.unitree.g1.mujoco_sim.G1SimConnection", + "g1-mujoco-connection": "dimos.robot.unitree.g1.mujoco_sim.G1MujocoConnection", + "g1-web-rtc-connection": "dimos.robot.unitree.g1.connection.G1WebRtcConnection", "g1-whole-body-connection": "dimos.robot.unitree.g1.wholebody_connection.G1WholeBodyConnection", - "go2-connection": "dimos.robot.unitree.go2.connection.GO2Connection", "go2-fleet-connection": "dimos.robot.unitree.go2.fleet_connection.Go2FleetConnection", "go2-memory": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2.Go2Memory", + "go2-mujoco-connection": "dimos.robot.unitree.go2.connection_mujoco.Go2MujocoConnection", + "go2-replay-connection": "dimos.robot.unitree.go2.connection_replay.Go2ReplayConnection", + "go2-web-rtc-connection": "dimos.robot.unitree.go2.connection_webrtc.Go2WebRtcConnection", "google-maps-skill-container": "dimos.agents.skills.google_maps_skill_container.GoogleMapsSkillContainer", "gps-nav-skill-container": "dimos.agents.skills.gps_nav_skill.GpsNavSkillContainer", "grasp-gen-module": "dimos.manipulation.grasping.graspgen_module.GraspGenModule", @@ -162,6 +163,7 @@ "module-a": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleA", "module-b": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleB", "movement-manager": "dimos.navigation.movement_manager.movement_manager.MovementManager", + "mujoco-connection-base": "dimos.robot.unitree.mujoco_connection.MujocoConnectionBase", "mujoco-sim-module": "dimos.simulation.engines.mujoco_sim_module.MujocoSimModule", "nav-record": "dimos.navigation.nav_stack.modules.nav_record.nav_record.NavRecord", "navigation-skill-container": "dimos.agents.skills.navigation.NavigationSkillContainer", diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index 7506fbd375..050bed8972 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -256,8 +256,17 @@ def run( # Workers inherit DIMOS_RUN_LOG_DIR env var via forkserver. set_run_log_dir(log_dir) + # Apply CLI overrides to global_config before importing blueprints so any + # blueprint-time conditionals (e.g. reading global_config.simulator) see + # the requested values. + if cli_config_overrides: + global_config.update(**cli_config_overrides) + blueprint = autoconnect(*map(get_by_name_or_exit, robot_types)) + if backend := global_config.effective_simulator: + blueprint = blueprint.with_backend(backend) + if disable: disabled_classes = tuple( get_module_by_name_or_exit(name).blueprints[0].module for name in disable diff --git a/dimos/robot/connection_registry.py b/dimos/robot/connection_registry.py new file mode 100644 index 0000000000..cf42863b94 --- /dev/null +++ b/dimos/robot/connection_registry.py @@ -0,0 +1,64 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Registry for per-(robot, backend) connection modules. + +A connection module declares its identity with the `@connection` decorator: + + @connection(robot="go2", backend="webrtc") + class Go2WebRtcConnection(Module): ... + + @connection(robot="go2", backend="mujoco") + class Go2MujocoConnection(Module): ... + +`Blueprint.with_backend("mujoco")` walks a blueprint's atoms and, for each +tagged atom whose backend differs from the requested one, looks up the +same-robot module for the requested backend in `_REGISTRY` and substitutes it. +""" + +from collections.abc import Callable +from typing import TypeVar + +from dimos.core.module import ConnectionTag, ModuleBase + +T = TypeVar("T", bound=type[ModuleBase]) + + +_REGISTRY: dict[tuple[str, str], type[ModuleBase]] = {} + + +def connection(*, robot: str, backend: str) -> Callable[[T], T]: + """Class decorator that tags a Module as the (robot, backend) connection.""" + + def deco(cls: T) -> T: + tag = ConnectionTag(robot=robot, backend=backend) + existing = _REGISTRY.get((robot, backend)) + if existing is not None and existing is not cls: + raise ValueError( + f"Duplicate connection registration for ({robot!r}, {backend!r}): " + f"{existing.__name__} and {cls.__name__}" + ) + cls._connection_tag = tag + _REGISTRY[(robot, backend)] = cls + return cls + + return deco + + +def get_connection(robot: str, backend: str) -> type[ModuleBase] | None: + return _REGISTRY.get((robot, backend)) + + +def backends_for(robot: str) -> set[str]: + return {backend for (r, backend) in _REGISTRY if r == robot} diff --git a/dimos/robot/drone/connection_module.py b/dimos/robot/drone/connection_module.py index a2cec07e79..588d383d58 100644 --- a/dimos/robot/drone/connection_module.py +++ b/dimos/robot/drone/connection_module.py @@ -32,12 +32,12 @@ from dimos.mapping.models import LatLon from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion -from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.sensor_msgs.Image import Image from dimos.robot.drone.dji_video_stream import DJIDroneVideoStream from dimos.robot.drone.mavlink_connection import MavlinkConnection +from dimos.robot.tf_utils import odom_to_tf from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -169,31 +169,16 @@ def _store_and_publish_frame(self, frame: Image) -> None: self.video.publish(frame) def _publish_tf(self, msg: PoseStamped) -> None: - """Publish odometry and TF transforms.""" self._odom = msg - - # Publish odometry self.odom.publish(msg) - - # Publish base_link transform - base_link = Transform( - translation=msg.position, - rotation=msg.orientation, - frame_id="world", - child_frame_id="base_link", - ts=msg.ts if hasattr(msg, "ts") else time.time(), - ) - self.tf.publish(base_link) - - # Publish camera_link transform (camera mounted on front of drone, no gimbal factored in yet) - camera_link = Transform( - translation=Vector3(0.1, 0.0, -0.05), # 10cm forward, 5cm down - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # No rotation relative to base - frame_id="base_link", - child_frame_id="camera_link", - ts=time.time(), + # 10cm forward, 5cm down. No optical frame published — gimbal not modeled yet. + self.tf.publish( + *odom_to_tf( + msg, + camera_link_offset=Vector3(0.1, 0.0, -0.05), + with_optical=False, + ) ) - self.tf.publish(camera_link) def _publish_status(self, status: dict[str, Any]) -> None: """Publish drone status as JSON string.""" diff --git a/dimos/robot/tf_utils.py b/dimos/robot/tf_utils.py new file mode 100644 index 0000000000..fd7fdcd027 --- /dev/null +++ b/dimos/robot/tf_utils.py @@ -0,0 +1,64 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Iterable + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +# REP-103/REP-105 optical-frame rotation: camera_link -> camera_optical. +_CAMERA_OPTICAL_ROTATION = Quaternion(-0.5, 0.5, -0.5, 0.5) + + +def odom_to_tf( + odom: PoseStamped, + *, + camera_link_offset: Vector3 = Vector3(0.3, 0.0, 0.0), + with_optical: bool = True, + extras: Iterable[Transform] = (), +) -> list[Transform]: + """Build the standard TF chain from an odometry pose. + + Produces, in order: + odom.frame_id -> base_link (from the pose) + base_link -> camera_link (translation = camera_link_offset) + camera_link -> camera_optical (only if with_optical) + ...extras + """ + transforms: list[Transform] = [Transform.from_pose("base_link", odom)] + transforms.append( + Transform( + translation=camera_link_offset, + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=odom.ts, + ) + ) + if with_optical: + transforms.append( + Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=_CAMERA_OPTICAL_ROTATION, + frame_id="camera_link", + child_frame_id="camera_optical", + ts=odom.ts, + ) + ) + transforms.extend(extras) + return transforms diff --git a/dimos/robot/unitree/connection.py b/dimos/robot/unitree/connection.py deleted file mode 100644 index 919efc76f6..0000000000 --- a/dimos/robot/unitree/connection.py +++ /dev/null @@ -1,435 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from dataclasses import dataclass -import functools -import threading -import time -from typing import Any, TypeAlias - -import numpy as np -from numpy.typing import NDArray -from reactivex import operators as ops -from reactivex.observable import Observable -from reactivex.subject import Subject -from unitree_webrtc_connect.constants import ( - RTC_TOPIC, - SPORT_CMD, - VUI_COLOR, -) -from unitree_webrtc_connect.webrtc_driver import ( - UnitreeWebRTCConnection as LegionConnection, - WebRTCConnectionMethod, -) - -from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT -from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs.Pose import Pose -from dimos.msgs.geometry_msgs.Transform import Transform -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.sensor_msgs.Image import Image, ImageFormat -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.robot.unitree.type.lidar import ( - RawLidarMsg, - pointcloud2_from_webrtc_lidar, - repair_stale_ts, -) -from dimos.robot.unitree.type.lowstate import LowStateMsg -from dimos.robot.unitree.type.odometry import Odometry -from dimos.utils.decorators.decorators import simple_mcache -from dimos.utils.reactive import backpressure, callback_to_observable - -VideoMessage: TypeAlias = NDArray[np.uint8] # Shape: (height, width, 3) - - -@dataclass -class SerializableVideoFrame: - """Pickleable wrapper for av.VideoFrame with all metadata""" - - data: np.ndarray - pts: int | None = None - time: float | None = None - dts: int | None = None - width: int | None = None - height: int | None = None - format: str | None = None - - @classmethod - def from_av_frame(cls, frame): # type: ignore[no-untyped-def] - return cls( - data=frame.to_ndarray(format="rgb24"), - pts=frame.pts, - time=frame.time, - dts=frame.dts, - width=frame.width, - height=frame.height, - format=frame.format.name if hasattr(frame, "format") and frame.format else None, - ) - - def to_ndarray(self, format=None): # type: ignore[no-untyped-def] - return self.data - - -class UnitreeWebRTCConnection(Resource): - _SPORT_API_ID_RAGEMODE: int = 2059 - - def __init__(self, ip: str, mode: str = "ai") -> None: - self.ip = ip - self.mode = mode - self.stop_timer: threading.Timer | None = None - self.cmd_vel_timeout = 0.2 - self.conn = LegionConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) - self.connect() - - def connect(self) -> None: - self.loop = asyncio.new_event_loop() - self.task = None - self.connected_event = asyncio.Event() - self.connection_ready = threading.Event() - - async def async_connect() -> None: - await self.conn.connect() - await self.conn.datachannel.disableTrafficSaving(True) - - self.conn.datachannel.set_decoder(decoder_type="native") - - await self.conn.datachannel.pub_sub.publish_request_new( - RTC_TOPIC["MOTION_SWITCHER"], {"api_id": 1002, "parameter": {"name": self.mode}} - ) - - self.connected_event.set() - self.connection_ready.set() - - while True: - await asyncio.sleep(1) - - def start_background_loop() -> None: - asyncio.set_event_loop(self.loop) - self.task = self.loop.create_task(async_connect()) - self.loop.run_forever() - - self.loop = asyncio.new_event_loop() - self.thread = threading.Thread(target=start_background_loop, daemon=True) - self.thread.start() - self.connection_ready.wait() - - def start(self) -> None: - pass - - def stop(self) -> None: - # Cancel timer - if self.stop_timer: - self.stop_timer.cancel() - self.stop_timer = None - - if self.task: - self.task.cancel() - - async def async_disconnect() -> None: - try: - # Send stop command directly since we're already in the event loop. - self.conn.datachannel.pub_sub.publish_without_callback( - RTC_TOPIC["WIRELESS_CONTROLLER"], - data={"lx": 0, "ly": 0, "rx": 0, "ry": 0}, - ) - await self.conn.disconnect() - except Exception: - pass - - if self.loop.is_running(): - asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) - - self.loop.call_soon_threadsafe(self.loop.stop) - - if self.thread.is_alive(): - self.thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) - - def move(self, twist: Twist, duration: float = 0.0) -> bool: - """Send movement command to the robot using Twist commands. - - Args: - twist: Twist message with linear and angular velocities - duration: How long to move (seconds). If 0, command is continuous - - Returns: - bool: True if command was sent successfully - """ - x, y, yaw = twist.linear.x, twist.linear.y, twist.angular.z - - # WebRTC coordinate mapping: - # x - Positive right, negative left - # y - positive forward, negative backwards - # yaw - Positive rotate right, negative rotate left - async def async_move() -> None: - self.conn.datachannel.pub_sub.publish_without_callback( - RTC_TOPIC["WIRELESS_CONTROLLER"], - data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, - ) - - async def async_move_duration() -> None: - """Send movement commands continuously for the specified duration.""" - start_time = time.time() - sleep_time = 0.01 - - while time.time() - start_time < duration: - await async_move() - await asyncio.sleep(sleep_time) - - # Cancel existing timer and start a new one - if self.stop_timer: - self.stop_timer.cancel() - - # Auto-stop after 0.5 seconds if no new commands - self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop_movement) - self.stop_timer.daemon = True - self.stop_timer.start() - - try: - if duration > 0: - # Send continuous move commands for the duration - future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) - future.result() - # Stop after duration - self.stop_movement() - else: - # Single command for continuous movement - future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) - future.result() - return True - except Exception as e: - print(f"Failed to send movement command: {e}") - return False - - # Generic conversion of unitree subscription to Subject (used for all subs) - def unitree_sub_stream(self, topic_name: str): # type: ignore[no-untyped-def] - def subscribe_in_thread(cb) -> None: # type: ignore[no-untyped-def] - # Run the subscription in the background thread that has the event loop - def run_subscription() -> None: - self.conn.datachannel.pub_sub.subscribe(topic_name, cb) - - # Use call_soon_threadsafe to run in the background thread - self.loop.call_soon_threadsafe(run_subscription) - - def unsubscribe_in_thread(cb) -> None: # type: ignore[no-untyped-def] - # Run the unsubscription in the background thread that has the event loop - def run_unsubscription() -> None: - self.conn.datachannel.pub_sub.unsubscribe(topic_name) - - # Use call_soon_threadsafe to run in the background thread - self.loop.call_soon_threadsafe(run_unsubscription) - - return callback_to_observable( - start=subscribe_in_thread, - stop=unsubscribe_in_thread, - ) - - # Generic sync API call (we jump into the client thread) - def publish_request(self, topic: str, data: dict[Any, Any]) -> Any: - future = asyncio.run_coroutine_threadsafe( - self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop - ) - return future.result() - - @simple_mcache - def raw_lidar_stream(self) -> Observable[RawLidarMsg]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) - - @simple_mcache - def raw_odom_stream(self) -> Observable[Pose]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) - - @simple_mcache - def lidar_stream(self) -> Observable[PointCloud2]: - return backpressure( - self.raw_lidar_stream().pipe( - ops.map(pointcloud2_from_webrtc_lidar), - repair_stale_ts(), - ) - ) - - @simple_mcache - def tf_stream(self) -> Observable[Transform]: - base_link = functools.partial(Transform.from_pose, "base_link") - return backpressure(self.odom_stream().pipe(ops.map(base_link))) - - @simple_mcache - def odom_stream(self) -> Observable[Pose]: - return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) - - @simple_mcache - def video_stream(self) -> Observable[Image]: - return backpressure( - self.raw_video_stream().pipe( - ops.filter(lambda frame: frame is not None), - ops.map( - lambda frame: Image.from_numpy( - # np.ascontiguousarray(frame.to_ndarray("rgb24")), - frame.to_ndarray(format="rgb24"), # type: ignore[attr-defined] - format=ImageFormat.RGB, # Frame is RGB24, not BGR - frame_id="camera_optical", - ) - ), - ) - ) - - @simple_mcache - def lowstate_stream(self) -> Observable[LowStateMsg]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) - - def standup(self) -> bool: - return bool(self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]})) - - def balance_stand(self) -> bool: - """Activate BalanceStand mode — enables WIRELESS_CONTROLLER joystick commands.""" - return bool( - self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) - ) - - def set_obstacle_avoidance(self, enabled: bool = True) -> None: - self.publish_request( - RTC_TOPIC["OBSTACLES_AVOID"], - {"api_id": 1001, "parameter": {"enable": int(enabled)}}, - ) - - def free_walk(self) -> bool: - """Activate FreeWalk locomotion mode — enables walking and velocity commands.""" - return bool(self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["FreeWalk"]})) - - def enable_rage_mode(self) -> bool: - """Enable Rage Mode on the Go2 via WebRTC. - Assumes the robot is already in BalanceStand. - """ - rage_ok = bool( - self.publish_request( - RTC_TOPIC["SPORT_MOD"], - {"api_id": self._SPORT_API_ID_RAGEMODE, "parameter": {"data": True}}, - ) - ) - time.sleep(2.0) - - joystick_ok = bool( - self.publish_request( - RTC_TOPIC["SPORT_MOD"], - { - "api_id": SPORT_CMD["SwitchJoystick"], - "parameter": {"data": True}, - }, - ) - ) - return rage_ok and joystick_ok - - def liedown(self) -> bool: - return bool( - self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) - ) - - async def handstand(self): # type: ignore[no-untyped-def] - return self.publish_request( - RTC_TOPIC["SPORT_MOD"], - {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, - ) - - def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: - return self.publish_request( # type: ignore[no-any-return] - RTC_TOPIC["VUI"], - { - "api_id": 1001, - "parameter": { - "color": color, - "time": colortime, - }, - }, - ) - - @simple_mcache - def raw_video_stream(self) -> Observable[VideoMessage]: - subject: Subject[VideoMessage] = Subject() - stop_event = threading.Event() - - from aiortc import MediaStreamTrack - - async def accept_track(track: MediaStreamTrack) -> None: - while True: - if stop_event.is_set(): - return - frame = await track.recv() - serializable_frame = SerializableVideoFrame.from_av_frame(frame) # type: ignore[no-untyped-call] - subject.on_next(serializable_frame) - - self.conn.video.add_track_callback(accept_track) - - # Run the video channel switching in the background thread - def switch_video_channel() -> None: - self.conn.video.switchVideoChannel(True) - - self.loop.call_soon_threadsafe(switch_video_channel) - - def stop() -> None: - stop_event.set() # Signal the loop to stop - self.conn.video.track_callbacks.remove(accept_track) - - # Run the video channel switching off in the background thread - def switch_video_channel_off() -> None: - self.conn.video.switchVideoChannel(False) - - self.loop.call_soon_threadsafe(switch_video_channel_off) - - return subject.pipe(ops.finally_action(stop)) - - def get_video_stream(self, fps: int = 30) -> Observable[Image]: - """Get the video stream from the robot's camera. - - Implements the AbstractRobot interface method. - - Args: - fps: Frames per second. This parameter is included for API compatibility, - but doesn't affect the actual frame rate which is determined by the camera. - - Returns: - Observable: An observable stream of video frames or None if video is not available. - """ - return self.video_stream() - - def stop_movement(self) -> None: - """Cancel the auto-stop timer (used by move() for continuous commands).""" - if self.stop_timer: - self.stop_timer.cancel() - self.stop_timer = None - - def disconnect(self) -> None: - """Disconnect from the robot and clean up resources.""" - # Cancel timer - if self.stop_timer: - self.stop_timer.cancel() - self.stop_timer = None - - if hasattr(self, "task") and self.task: - self.task.cancel() - if hasattr(self, "conn"): - - async def async_disconnect() -> None: - try: - await self.conn.disconnect() - except: - pass - - if hasattr(self, "loop") and self.loop.is_running(): - asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) - - if hasattr(self, "loop") and self.loop.is_running(): - self.loop.call_soon_threadsafe(self.loop.stop) - - if hasattr(self, "thread") and self.thread.is_alive(): - self.thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) diff --git a/dimos/robot/unitree/g1/__init__.py b/dimos/robot/unitree/g1/__init__.py new file mode 100644 index 0000000000..b68ec1e22c --- /dev/null +++ b/dimos/robot/unitree/g1/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Eager-import G1 connection variants so the connection registry is fully +# populated by the time `Blueprint.with_backend(...)` runs. The Mujoco +# variant defers its mujoco-engine import to instance __init__. +from dimos.robot.unitree.g1 import ( # noqa: F401 + connection, + mujoco_sim, +) diff --git a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic.py b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic.py index 98248916ea..3b6bb104e5 100644 --- a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic.py +++ b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic.py @@ -19,11 +19,11 @@ from dimos.robot.unitree.g1.blueprints.primitive.uintree_g1_primitive_no_nav import ( uintree_g1_primitive_no_nav, ) -from dimos.robot.unitree.g1.connection import G1Connection +from dimos.robot.unitree.g1.connection import G1WebRtcConnection unitree_g1_basic = autoconnect( uintree_g1_primitive_no_nav, - G1Connection.blueprint(), + G1WebRtcConnection.blueprint(), ) __all__ = ["unitree_g1_basic"] diff --git a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py index 32d2d52b8b..edafe442be 100644 --- a/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py +++ b/dimos/robot/unitree/g1/blueprints/basic/unitree_g1_basic_sim.py @@ -20,11 +20,11 @@ from dimos.robot.unitree.g1.blueprints.primitive.uintree_g1_primitive_no_nav import ( uintree_g1_primitive_no_nav, ) -from dimos.robot.unitree.g1.mujoco_sim import G1SimConnection +from dimos.robot.unitree.g1.mujoco_sim import G1MujocoConnection unitree_g1_basic_sim = autoconnect( uintree_g1_primitive_no_nav, - G1SimConnection.blueprint(), + G1MujocoConnection.blueprint(), ReplanningAStarPlanner.blueprint(), ) diff --git a/dimos/robot/unitree/g1/connection.py b/dimos/robot/unitree/g1/connection.py index 58b2a0747f..596ca1753d 100644 --- a/dimos/robot/unitree/g1/connection.py +++ b/dimos/robot/unitree/g1/connection.py @@ -12,107 +12,66 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +"""Real-hardware G1 connection over Unitree's WebRTC stack. + +Composes `UnitreeWebRtcSession` for the asyncio loop / handshake / move / +publish_request. +""" + +from __future__ import annotations + +from typing import Any from pydantic import Field from reactivex.disposable import Disposable -from dimos.core.coordination.module_coordinator import ModuleCoordinator from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.robot.unitree.connection import UnitreeWebRTCConnection -from dimos.spec.control import LocalPlanner +from dimos.robot.connection_registry import connection +from dimos.robot.unitree.webrtc_session import UnitreeWebRtcSession from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from dimos.core.rpc_client import ModuleProxy - logger = setup_logger() class G1Config(ModuleConfig): ip: str = Field(default_factory=lambda m: m["g"].robot_ip) - connection_type: str = Field(default_factory=lambda m: m["g"].unitree_connection_type) - - -class G1ConnectionBase(Module, ABC): - """Abstract base for G1 connections (real hardware and simulation). - - Modules that depend on G1 connection RPC methods should reference this - base class so the blueprint wiring works regardless of which concrete - connection is deployed. - """ - - config: ModuleConfig - - @rpc - @abstractmethod - def start(self) -> None: - super().start() - @rpc - @abstractmethod - def stop(self) -> None: - super().stop() - - @rpc - @abstractmethod - def move(self, twist: Twist, duration: float = 0.0) -> None: ... - - @rpc - @abstractmethod - def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: ... +@connection(robot="g1", backend="webrtc") +class G1WebRtcConnection(Module): + """Real-hardware G1 connection over Unitree's WebRTC stack.""" -class G1Connection(G1ConnectionBase): config: G1Config cmd_vel: In[Twist] - connection: UnitreeWebRTCConnection | None = None + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.session: UnitreeWebRtcSession | None = None @rpc def start(self) -> None: super().start() - - match self.config.connection_type: - case "webrtc": - self.connection = UnitreeWebRTCConnection(self.config.ip) - case "replay": - raise ValueError("Replay connection not implemented for G1 robot") - case "mujoco": - raise ValueError( - "This module does not support simulation, use G1SimConnection instead" - ) - case _: - raise ValueError(f"Unknown connection type: {self.config.connection_type}") - - assert self.connection is not None - self.connection.start() - + self.session = UnitreeWebRtcSession(self.config.ip) + self.session.start() self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) @rpc def stop(self) -> None: - assert self.connection is not None - self.connection.stop() + if self.session is not None: + self.session.stop() + self.session = None super().stop() @rpc def move(self, twist: Twist, duration: float = 0.0) -> None: - assert self.connection is not None - self.connection.move(twist, duration) + assert self.session is not None + self.session.move(twist, duration) @rpc def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: logger.info(f"Publishing request to topic: {topic} with data: {data}") - assert self.connection is not None - return self.connection.publish_request(topic, data) # type: ignore[no-any-return] - - -def deploy(dimos: ModuleCoordinator, ip: str, local_planner: LocalPlanner) -> "ModuleProxy": - connection = dimos.deploy(G1Connection, ip=ip) - connection.cmd_vel.connect(local_planner.cmd_vel) - connection.start() - return connection + assert self.session is not None + return self.session.publish_request(topic, data) # type: ignore[no-any-return] diff --git a/dimos/robot/unitree/g1/effectors/high_level/webrtc.py b/dimos/robot/unitree/g1/effectors/high_level/webrtc.py index 1784605ac2..6ae39ba967 100644 --- a/dimos/robot/unitree/g1/effectors/high_level/webrtc.py +++ b/dimos/robot/unitree/g1/effectors/high_level/webrtc.py @@ -15,6 +15,7 @@ from typing import Any from reactivex.disposable import Disposable +from unitree_webrtc_connect.constants import RTC_TOPIC, SPORT_CMD from dimos.agents.annotation import skill from dimos.core.core import rpc @@ -23,7 +24,6 @@ from dimos.core.stream import In from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.robot.unitree.connection import UnitreeWebRTCConnection from dimos.robot.unitree.g1.effectors.high_level.commands import ( ARM_API_ID, ARM_COMMANDS, @@ -36,6 +36,7 @@ execute_g1_command, ) from dimos.robot.unitree.g1.effectors.high_level.high_level_spec import HighLevelG1Spec +from dimos.robot.unitree.webrtc_session import UnitreeWebRtcSession from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -52,52 +53,57 @@ class G1HighLevelWebRtc(Module, HighLevelG1Spec): cmd_vel: In[Twist] config: G1HighLevelWebRtcConfig - connection: UnitreeWebRTCConnection | None + session: UnitreeWebRtcSession | None def __init__(self, *args: Any, g: GlobalConfig = global_config, **kwargs: Any) -> None: super().__init__(*args, g=g, **kwargs) self._global_config = g + self.session = None @rpc def start(self) -> None: super().start() assert self.config.ip is not None, "ip must be set in G1HighLevelWebRtcConfig" - self.connection = UnitreeWebRTCConnection(self.config.ip, self.config.connection_mode) - self.connection.start() + self.session = UnitreeWebRtcSession(self.config.ip, mode_name=self.config.connection_mode) + self.session.start() self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) @rpc def stop(self) -> None: - if self.connection is not None: - self.connection.stop() + if self.session is not None: + self.session.stop() super().stop() @rpc def move(self, twist: Twist, duration: float = 0.0) -> bool: - assert self.connection is not None - return self.connection.move(twist, duration) + assert self.session is not None + return self.session.move(twist, duration) @rpc def get_state(self) -> str: - if self.connection is None: + if self.session is None: return "Not connected" return "Connected (WebRTC)" @rpc def publish_request(self, topic: str, data: dict[str, Any]) -> dict[str, Any]: logger.info(f"Publishing request to topic: {topic} with data: {data}") - assert self.connection is not None - return self.connection.publish_request(topic, data) # type: ignore[no-any-return] + assert self.session is not None + return self.session.publish_request(topic, data) # type: ignore[no-any-return] @rpc def stand_up(self) -> bool: - assert self.connection is not None - return self.connection.standup() + assert self.session is not None + return bool( + self.session.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) + ) @rpc def lie_down(self) -> bool: - assert self.connection is not None - return self.connection.liedown() + assert self.session is not None + return bool( + self.session.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) + ) @skill def move_velocity( diff --git a/dimos/robot/unitree/g1/mujoco_sim.py b/dimos/robot/unitree/g1/mujoco_sim.py index f1c9e92310..cde0214b13 100644 --- a/dimos/robot/unitree/g1/mujoco_sim.py +++ b/dimos/robot/unitree/g1/mujoco_sim.py @@ -12,141 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -import threading -from threading import Thread -import time -from typing import Any +"""MuJoCo-simulated G1 connection. + +A thin subclass of `MujocoConnectionBase` that sets the G1's camera mounting +offset and publishes an additional `map -> world` transform. +""" + +from __future__ import annotations from pydantic import Field -from reactivex.disposable import Disposable -from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT -from dimos.core.core import rpc from dimos.core.module import ModuleConfig -from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Transform import Transform -from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo -from dimos.msgs.sensor_msgs.Image import Image -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.robot.unitree.g1.connection import G1ConnectionBase -from dimos.robot.unitree.mujoco_connection import MujocoConnection -from dimos.robot.unitree.type.odometry import Odometry as SimOdometry -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() +from dimos.robot.connection_registry import connection +from dimos.robot.unitree.mujoco_connection import MujocoConnectionBase class G1SimConfig(ModuleConfig): ip: str = Field(default_factory=lambda m: m["g"].robot_ip) -class G1SimConnection(G1ConnectionBase): - config: G1SimConfig - cmd_vel: In[Twist] - lidar: Out[PointCloud2] - odom: Out[PoseStamped] - color_image: Out[Image] - camera_info: Out[CameraInfo] - connection: MujocoConnection | None = None - _camera_info_thread: Thread | None = None - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._stop_event = threading.Event() - - @rpc - def start(self) -> None: - super().start() - - from dimos.robot.unitree.mujoco_connection import MujocoConnection - - self.connection = MujocoConnection(self.config.g) - assert self.connection is not None - self.connection.start() - - self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) - self.register_disposable(self.connection.odom_stream().subscribe(self._publish_sim_odom)) - self.register_disposable(self.connection.lidar_stream().subscribe(self.lidar.publish)) - self.register_disposable(self.connection.video_stream().subscribe(self.color_image.publish)) - - self._camera_info_thread = Thread( - target=self._publish_camera_info_loop, - daemon=True, - ) - self._camera_info_thread.start() - - @rpc - def stop(self) -> None: - self._stop_event.set() - assert self.connection is not None - self.connection.stop() - if self._camera_info_thread and self._camera_info_thread.is_alive(): - self._camera_info_thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) - super().stop() +@connection(robot="g1", backend="mujoco") +class G1MujocoConnection(MujocoConnectionBase): + """MuJoCo-simulated G1 connection.""" - def _publish_camera_info_loop(self) -> None: - assert self.connection is not None - info = self.connection.camera_info_static - while not self._stop_event.is_set(): - self.camera_info.publish(info) - self._stop_event.wait(1.0) - - def _publish_tf(self, msg: PoseStamped) -> None: - self.odom.publish(msg) - - self.tf.publish(Transform.from_pose("base_link", msg)) - - # Publish camera_link and camera_optical transforms - camera_link = Transform( - translation=Vector3(0.05, 0.0, 0.6), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="camera_link", - ts=time.time(), - ) - - camera_optical = Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), - frame_id="camera_link", - child_frame_id="camera_optical", - ts=time.time(), - ) - - map_to_world = Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="map", - child_frame_id="world", - ts=time.time(), - ) + config: G1SimConfig - self.tf.publish(camera_link, camera_optical, map_to_world) + _camera_link_offset: Vector3 = Vector3(0.05, 0.0, 0.6) - def _publish_sim_odom(self, msg: SimOdometry) -> None: - self._publish_tf( - PoseStamped( + def _extra_transforms(self, msg: PoseStamped) -> list[Transform]: + return [ + Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="map", + child_frame_id="world", ts=msg.ts, - frame_id=msg.frame_id, - position=msg.position, - orientation=msg.orientation, - ) - ) - - @rpc - def move(self, twist: Twist, duration: float = 0.0) -> None: - assert self.connection is not None - self.connection.move(twist, duration) - - @rpc - def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: - logger.info(f"Publishing request to topic: {topic} with data: {data}") - assert self.connection is not None - return self.connection.publish_request(topic, data) - - -__all__ = ["G1SimConnection"] + ), + ] diff --git a/dimos/robot/unitree/go2/__init__.py b/dimos/robot/unitree/go2/__init__.py new file mode 100644 index 0000000000..080f425129 --- /dev/null +++ b/dimos/robot/unitree/go2/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Eager-import all Go2 connection variants so the connection registry is +# fully populated by the time `Blueprint.with_backend(...)` runs. The Mujoco +# variant defers its mujoco-engine import to instance __init__, so registering +# it here is cheap. +from dimos.robot.unitree.go2 import ( + connection_mujoco, # noqa: F401 + connection_replay, # noqa: F401 + connection_webrtc, # noqa: F401 +) diff --git a/dimos/robot/unitree/go2/blueprints/agentic/_common_agentic.py b/dimos/robot/unitree/go2/blueprints/agentic/_common_agentic.py index 93312225bc..5fa0085787 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/_common_agentic.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/_common_agentic.py @@ -18,12 +18,12 @@ from dimos.agents.skills.speak_skill import SpeakSkill from dimos.agents.web_human_input import WebInput from dimos.core.coordination.blueprints import autoconnect -from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.robot.unitree.go2.camera import GO2_CAMERA_INFO_STATIC from dimos.robot.unitree.unitree_skill_container import UnitreeSkillContainer _common_agentic = autoconnect( NavigationSkillContainer.blueprint(), - PersonFollowSkillContainer.blueprint(camera_info=GO2Connection.camera_info_static), + PersonFollowSkillContainer.blueprint(camera_info=GO2_CAMERA_INFO_STATIC), UnitreeSkillContainer.blueprint(), WebInput.blueprint(), SpeakSkill.blueprint(), diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index 4f86ccb0a3..53c9e372f5 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -23,7 +23,7 @@ from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator -from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.robot.unitree.go2.connection_webrtc import Go2WebRtcConnection from dimos.visualization.vis_module import vis_module # Mac has some issue with high bandwidth UDP, so we use pSHMTransport for color_image @@ -132,7 +132,7 @@ def _go2_rerun_blueprint() -> Any: unitree_go2_basic = ( autoconnect( _with_vis, - GO2Connection.blueprint(), + Go2WebRtcConnection.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_coordinator.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_coordinator.py index ef8263fbf3..cbfbb9e64f 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_coordinator.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_coordinator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unitree Go2 ControlCoordinator — GO2Connection + coordinator via LCM transport adapter. +"""Unitree Go2 ControlCoordinator — Go2WebRtcConnection + coordinator via LCM transport adapter. Usage: dimos run unitree-go2-coordinator @@ -29,13 +29,13 @@ from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.robot.unitree.go2.connection_webrtc import Go2WebRtcConnection _go2_joints = make_twist_base_joints("go2") unitree_go2_coordinator = ( autoconnect( - GO2Connection.blueprint(), + Go2WebRtcConnection.blueprint(), ControlCoordinator.blueprint( hardware=[ HardwareComponent( @@ -57,8 +57,8 @@ ) .remappings( [ - (GO2Connection, "cmd_vel", "go2_cmd_vel"), - (GO2Connection, "odom", "go2_odom"), + (Go2WebRtcConnection, "cmd_vel", "go2_cmd_vel"), + (Go2WebRtcConnection, "odom", "go2_odom"), ] ) .transports( diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_rage_keyboard_teleop.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_rage_keyboard_teleop.py index c2fd9a3b87..ae80641661 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_rage_keyboard_teleop.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_rage_keyboard_teleop.py @@ -29,12 +29,12 @@ from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_webrtc_keyboard_teleop import ( unitree_go2_webrtc_keyboard_teleop, ) -from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.robot.unitree.go2.connection_webrtc import Go2WebRtcConnection from dimos.robot.unitree.keyboard_teleop import KeyboardTeleop unitree_go2_webrtc_rage_keyboard_teleop = autoconnect( unitree_go2_webrtc_keyboard_teleop, - GO2Connection.blueprint(mode="rage"), + Go2WebRtcConnection.blueprint(mode="rage"), KeyboardTeleop.blueprint(linear_speed=1.25, angular_speed=1.2), ).global_config(obstacle_avoidance=True) diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py index 799d5bf211..8ce26b6804 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py @@ -25,13 +25,13 @@ from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.module3D import Detection3DModule from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 -from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.robot.unitree.go2.camera import GO2_CAMERA_INFO_STATIC unitree_go2_detection = ( autoconnect( unitree_go2, Detection3DModule.blueprint( - camera_info=GO2Connection.camera_info_static, + camera_info=GO2_CAMERA_INFO_STATIC, ), ) .remappings( diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_spatial.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_spatial.py index 97a0803b98..7d0cf47ea4 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_spatial.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_spatial.py @@ -18,13 +18,13 @@ from dimos.perception.perceive_loop_skill import PerceiveLoopSkill from dimos.perception.spatial_perception import SpatialMemory from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 -from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.robot.unitree.go2.camera import GO2_CAMERA_INFO_STATIC unitree_go2_spatial = autoconnect( unitree_go2, SpatialMemory.blueprint(), PerceiveLoopSkill.blueprint(), - SecurityModule.blueprint(camera_info=GO2Connection.camera_info_static), + SecurityModule.blueprint(camera_info=GO2_CAMERA_INFO_STATIC), ).global_config(n_workers=8) __all__ = ["unitree_go2_spatial"] diff --git a/dimos/robot/unitree/go2/camera.py b/dimos/robot/unitree/go2/camera.py new file mode 100644 index 0000000000..b7595d9c85 --- /dev/null +++ b/dimos/robot/unitree/go2/camera.py @@ -0,0 +1,36 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo + + +def _camera_info_static() -> CameraInfo: + fx, fy, cx, cy = (819.553492, 820.646595, 625.284099, 336.808987) + width, height = (1280, 720) + + return CameraInfo( + frame_id="camera_optical", + height=height, + width=width, + distortion_model="plumb_bob", + D=[0.0, 0.0, 0.0, 0.0, 0.0], + K=[fx, 0.0, cx, 0.0, fy, cy, 0.0, 0.0, 1.0], + R=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + P=[fx, 0.0, cx, 0.0, 0.0, fy, cy, 0.0, 0.0, 0.0, 1.0, 0.0], + binning_x=0, + binning_y=0, + ) + + +GO2_CAMERA_INFO_STATIC: CameraInfo = _camera_info_static() diff --git a/dimos/robot/unitree/go2/config.py b/dimos/robot/unitree/go2/config.py new file mode 100644 index 0000000000..9d00eb96aa --- /dev/null +++ b/dimos/robot/unitree/go2/config.py @@ -0,0 +1,38 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared config types for Go2 backend connections. + +Kept in one place so the registry-based backend swap sees the same +`ConnectionConfig.model_fields` shape across `Go2WebRtcConnection`, +`Go2MujocoConnection`, `Go2ReplayConnection`, and `Go2FleetConnection`. +""" + +from __future__ import annotations + +from enum import Enum + +from pydantic import Field + +from dimos.core.module import ModuleConfig + + +class Go2Mode(str, Enum): + DEFAULT = "default" + RAGE = "rage" + + +class ConnectionConfig(ModuleConfig): + ip: str = Field(default_factory=lambda m: m["g"].robot_ip) + mode: Go2Mode = Go2Mode.DEFAULT diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py deleted file mode 100644 index 98d1423dd6..0000000000 --- a/dimos/robot/unitree/go2/connection.py +++ /dev/null @@ -1,386 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum -import sys -from threading import Thread -import time -from typing import TYPE_CHECKING, Any, Protocol - -from pydantic import Field -from reactivex.disposable import Disposable -from reactivex.observable import Observable -import rerun.blueprint as rrb - -from dimos.agents.annotation import skill -from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT -from dimos.core.coordination.module_coordinator import ModuleCoordinator -from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import In, Out -from dimos.core.transport import LCMTransport, pSHMTransport -from dimos.spec.perception import Camera, Pointcloud -from dimos.utils.logging_config import setup_logger - -if TYPE_CHECKING: - from dimos.core.rpc_client import ModuleProxy -from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.geometry_msgs.Quaternion import Quaternion -from dimos.msgs.geometry_msgs.Transform import Transform -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo -from dimos.msgs.sensor_msgs.Image import Image -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.robot.unitree.connection import UnitreeWebRTCConnection -from dimos.utils.decorators.decorators import simple_mcache -from dimos.utils.testing.replay import TimedSensorReplay, TimedSensorStorage - -if sys.version_info < (3, 13): - from typing_extensions import TypeVar -else: - from typing import TypeVar - -logger = setup_logger() - - -class Go2Mode(str, Enum): - DEFAULT = "default" - RAGE = "rage" - - -class ConnectionConfig(ModuleConfig): - ip: str = Field(default_factory=lambda m: m["g"].robot_ip) - mode: Go2Mode = Go2Mode.DEFAULT - - -class Go2ConnectionProtocol(Protocol): - """Protocol defining the interface for Go2 robot connections.""" - - def start(self) -> None: ... - def stop(self) -> None: ... - def lidar_stream(self) -> Observable: ... # type: ignore[type-arg] - def odom_stream(self) -> Observable: ... # type: ignore[type-arg] - def video_stream(self) -> Observable: ... # type: ignore[type-arg] - def move(self, twist: Twist, duration: float = 0.0) -> bool: ... - def standup(self) -> bool: ... - def liedown(self) -> bool: ... - def balance_stand(self) -> bool: ... - def set_obstacle_avoidance(self, enabled: bool = True) -> None: ... - def enable_rage_mode(self) -> bool: ... - def publish_request(self, topic: str, data: dict) -> dict: ... # type: ignore[type-arg] - - -def _camera_info_static() -> CameraInfo: - fx, fy, cx, cy = (819.553492, 820.646595, 625.284099, 336.808987) - width, height = (1280, 720) - - return CameraInfo( - frame_id="camera_optical", - height=height, - width=width, - distortion_model="plumb_bob", - D=[0.0, 0.0, 0.0, 0.0, 0.0], - K=[fx, 0.0, cx, 0.0, fy, cy, 0.0, 0.0, 1.0], - R=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], - P=[fx, 0.0, cx, 0.0, 0.0, fy, cy, 0.0, 0.0, 0.0, 1.0, 0.0], - binning_x=0, - binning_y=0, - ) - - -# Static camera mount chain: base_link -> camera_link -> camera_optical. -# TODO we need a standardized way to specify this for all cameras in dimos -BASE_TO_OPTICAL: Transform = Transform( - translation=Vector3(0.3, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="camera_link", -) + Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), - frame_id="camera_link", - child_frame_id="camera_optical", -) - - -def make_connection(ip: str | None, cfg: GlobalConfig) -> Go2ConnectionProtocol: - connection_type = cfg.unitree_connection_type - - if ip in ("fake", "mock", "replay") or connection_type == "replay": - dataset = cfg.replay_db - return ReplayConnection(dataset=dataset) - elif ip == "mujoco" or connection_type == "mujoco": - from dimos.robot.unitree.mujoco_connection import MujocoConnection - - return MujocoConnection(cfg) - else: - assert ip is not None, "IP address must be provided" - return UnitreeWebRTCConnection(ip) - - -class ReplayConnection(UnitreeWebRTCConnection): - # we don't want UnitreeWebRTCConnection to init - def __init__( # type: ignore[no-untyped-def] - self, - dataset: str = "go2_china_office", - **kwargs, - ) -> None: - self.dataset = dataset - self.replay_config = { - "loop": kwargs.get("loop", True), - "seek": kwargs.get("seek"), - "duration": kwargs.get("duration"), - } - - def connect(self) -> None: - pass - - def start(self) -> None: - pass - - def standup(self) -> bool: - return True - - def liedown(self) -> bool: - return True - - def balance_stand(self) -> bool: - return True - - def set_obstacle_avoidance(self, enabled: bool = True) -> None: - pass - - def enable_rage_mode(self) -> bool: - return True - - @simple_mcache - def lidar_stream(self): # type: ignore[no-untyped-def] - lidar_store = TimedSensorReplay(f"{self.dataset}/lidar") # type: ignore[var-annotated] - return lidar_store.stream(**self.replay_config) - - @simple_mcache - def odom_stream(self): # type: ignore[no-untyped-def] - odom_store = TimedSensorReplay(f"{self.dataset}/odom") # type: ignore[var-annotated] - return odom_store.stream(**self.replay_config) - - @simple_mcache - def video_stream(self): # type: ignore[no-untyped-def] - video_store: TimedSensorReplay[Image] = TimedSensorReplay(f"{self.dataset}/color_image") - return video_store.stream(**self.replay_config) - - def move(self, twist: Twist, duration: float = 0.0) -> bool: - return True - - def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] - """Fake publish request for testing.""" - return {"status": "ok", "message": "Fake publish"} - - -_Config = TypeVar("_Config", bound=ConnectionConfig, default=ConnectionConfig) - - -class GO2Connection(Module, Camera, Pointcloud): - config: ConnectionConfig - cmd_vel: In[Twist] - pointcloud: Out[PointCloud2] - odom: Out[PoseStamped] - lidar: Out[PointCloud2] - color_image: Out[Image] - camera_info: Out[CameraInfo] - - connection: Go2ConnectionProtocol - camera_info_static: CameraInfo = _camera_info_static() - _camera_info_thread: Thread | None = None - _latest_video_frame: Image | None = None - - @classmethod - def rerun_views(cls): # type: ignore[no-untyped-def] - """Return Rerun view blueprints for GO2 camera visualization.""" - return [ - rrb.Spatial2DView( - name="Camera", - origin="world/robot/camera/rgb", - ), - ] - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.connection = make_connection(self.config.ip, self.config.g) - - if hasattr(self.connection, "camera_info_static"): - self.camera_info_static = self.connection.camera_info_static - - @rpc - def record(self, recording_name: str) -> None: - lidar_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/lidar") # type: ignore[type-arg] - lidar_store.consume_stream(self.connection.lidar_stream()) - - odom_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/odom") # type: ignore[type-arg] - odom_store.consume_stream(self.connection.odom_stream()) - - video_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/video") # type: ignore[type-arg] - video_store.consume_stream(self.connection.video_stream()) - - @rpc - def start(self) -> None: - super().start() - if not hasattr(self, "connection"): - return - self.connection.start() - - def onimage(image: Image) -> None: - self.color_image.publish(image) - self._latest_video_frame = image - - self.register_disposable(self.connection.lidar_stream().subscribe(self.lidar.publish)) - self.register_disposable(self.connection.odom_stream().subscribe(self._publish_tf)) - self.register_disposable(self.connection.video_stream().subscribe(onimage)) - self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) - - self._camera_info_thread = Thread( - target=self.publish_camera_info, - daemon=True, - ) - self._camera_info_thread.start() - - self.standup() - time.sleep(3) - self.connection.balance_stand() - - if self.config.mode == Go2Mode.RAGE: - self.connection.enable_rage_mode() - - self.connection.set_obstacle_avoidance(self.config.g.obstacle_avoidance) - - # self.record("go2_bigoffice") - - @rpc - def stop(self) -> None: - self.liedown() - - if self.connection: - self.connection.stop() - - if self._camera_info_thread and self._camera_info_thread.is_alive(): - self._camera_info_thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) - - super().stop() - - @classmethod - def _odom_to_tf(cls, odom: PoseStamped) -> list[Transform]: - camera_link = Transform( - translation=Vector3(0.3, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="camera_link", - ts=odom.ts, - ) - - camera_optical = Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), - frame_id="camera_link", - child_frame_id="camera_optical", - ts=odom.ts, - ) - - return [ - Transform.from_pose("base_link", odom), - camera_link, - camera_optical, - ] - - def _publish_tf(self, msg: PoseStamped) -> None: - transforms = self._odom_to_tf(msg) - self.tf.publish(*transforms) - if self.odom.transport: - self.odom.publish(msg) - - def publish_camera_info(self) -> None: - while True: - self.camera_info.publish(self.camera_info_static) - time.sleep(1.0) - - @rpc - def move(self, twist: Twist, duration: float = 0.0) -> bool: - """Send movement command to robot.""" - return self.connection.move(twist, duration) - - @rpc - def standup(self) -> bool: - """Make the robot stand up.""" - return self.connection.standup() - - @rpc - def liedown(self) -> bool: - """Make the robot lie down.""" - return self.connection.liedown() - - @rpc - def balance_stand(self) -> bool: - """Enter BalanceStand: neutral state for switching locomotion modes""" - return self.connection.balance_stand() - - @rpc - def enable_rage_mode(self) -> bool: - """Enable Rage Mode (~2.5 m/s forward velocity envelope). - Ensures BalanceStand precondition regardless of current FSM state. - """ - self.connection.balance_stand() - time.sleep(0.3) - result = self.connection.enable_rage_mode() - logger.info("Rage Mode enabled") - return result - - @rpc - def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: - """Publish a request to the WebRTC connection. - Args: - topic: The RTC topic to publish to - data: The data dictionary to publish - Returns: - The result of the publish request - """ - return self.connection.publish_request(topic, data) - - @skill - def observe(self) -> Image | None: - """Returns the latest video frame from the robot camera. Use this skill for any visual world queries. - - This skill provides the current camera view for perception tasks. - Returns None if no frame has been captured yet. - """ - return self._latest_video_frame - - -def deploy(dimos: ModuleCoordinator, ip: str, prefix: str = "") -> "ModuleProxy": - from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE - - connection = dimos.deploy(GO2Connection, ip=ip) - - connection.pointcloud.transport = pSHMTransport( - f"{prefix}/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE - ) - connection.color_image.transport = pSHMTransport( - f"{prefix}/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE - ) - - connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", Twist) - - connection.camera_info.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) - connection.start() - - return connection diff --git a/dimos/robot/unitree/go2/connection_mujoco.py b/dimos/robot/unitree/go2/connection_mujoco.py new file mode 100644 index 0000000000..91a6ceb987 --- /dev/null +++ b/dimos/robot/unitree/go2/connection_mujoco.py @@ -0,0 +1,97 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MuJoCo-simulated Go2 connection. + +A subclass of `MujocoConnectionBase` with Go2-specific camera mounting, +start/stop sequencing, robot RPC stubs, the perception protocols, and the +`observe()` skill. +""" + +from __future__ import annotations + +import time + +import rerun.blueprint as rrb + +from dimos.agents.annotation import skill +from dimos.core.core import rpc +from dimos.core.stream import Out +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.robot.connection_registry import connection +from dimos.robot.unitree.go2.config import ConnectionConfig, Go2Mode +from dimos.robot.unitree.mujoco_connection import MujocoConnectionBase +from dimos.spec.perception import Camera, Pointcloud + + +@connection(robot="go2", backend="mujoco") +class Go2MujocoConnection(MujocoConnectionBase, Camera, Pointcloud): + """MuJoCo-simulated Go2 connection.""" + + config: ConnectionConfig + pointcloud: Out[PointCloud2] + + _camera_link_offset: Vector3 = Vector3(0.3, 0.0, 0.0) + + @classmethod + def rerun_views(cls): # type: ignore[no-untyped-def] + return [ + rrb.Spatial2DView( + name="Camera", + origin="world/robot/camera/rgb", + ), + ] + + def _on_start(self) -> None: + self.standup() + time.sleep(3) + self.balance_stand() + + if self.config.mode == Go2Mode.RAGE: + self.enable_rage_mode() + + self.set_obstacle_avoidance(self.config.g.obstacle_avoidance) + + def _on_stop(self) -> None: + self.liedown() + + @rpc + def standup(self) -> bool: + return True + + @rpc + def liedown(self) -> bool: + return True + + @rpc + def balance_stand(self) -> bool: + return True + + @rpc + def enable_rage_mode(self) -> bool: + return True + + def set_obstacle_avoidance(self, enabled: bool = True) -> None: + pass + + @skill + def observe(self) -> Image | None: + """Returns the latest video frame from the robot camera. Use this skill for any visual world queries. + + This skill provides the current camera view for perception tasks. + Returns None if no frame has been captured yet. + """ + return self._latest_video_frame diff --git a/dimos/robot/unitree/go2/connection_replay.py b/dimos/robot/unitree/go2/connection_replay.py new file mode 100644 index 0000000000..d1ebeede1c --- /dev/null +++ b/dimos/robot/unitree/go2/connection_replay.py @@ -0,0 +1,166 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Go2 connection that replays sensor streams from a recorded dataset. + +Self-contained: no transport (datasets are read in-process), stream wiring, +and stubbed RPC surface in one Module class. No shared base. +""" + +from __future__ import annotations + +from threading import Thread +import time +from typing import Any + +from reactivex.disposable import Disposable +from reactivex.observable import Observable +import rerun.blueprint as rrb + +from dimos.agents.annotation import skill +from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.robot.connection_registry import connection +from dimos.robot.tf_utils import odom_to_tf +from dimos.robot.unitree.go2.camera import _camera_info_static +from dimos.robot.unitree.go2.config import ConnectionConfig +from dimos.spec.perception import Camera, Pointcloud +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing.replay import TimedSensorReplay + +logger = setup_logger() + + +@connection(robot="go2", backend="replay") +class Go2ReplayConnection(Module, Camera, Pointcloud): + """Go2 connection that replays a previously recorded dataset.""" + + config: ConnectionConfig + cmd_vel: In[Twist] + pointcloud: Out[PointCloud2] + odom: Out[PoseStamped] + lidar: Out[PointCloud2] + color_image: Out[Image] + camera_info: Out[CameraInfo] + + camera_info_static: CameraInfo = _camera_info_static() + _camera_info_thread: Thread | None = None + _latest_video_frame: Image | None = None + + @classmethod + def rerun_views(cls): # type: ignore[no-untyped-def] + return [ + rrb.Spatial2DView( + name="Camera", + origin="world/robot/camera/rgb", + ), + ] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.dataset = self.config.g.replay_db + self.replay_config: dict[str, Any] = { + "loop": True, + "seek": None, + "duration": None, + } + + @rpc + def start(self) -> None: + super().start() + + def onimage(image: Image) -> None: + self.color_image.publish(image) + self._latest_video_frame = image + + self.register_disposable(self._lidar_stream().subscribe(self.lidar.publish)) + self.register_disposable(self._odom_stream().subscribe(self._publish_tf)) + self.register_disposable(self._video_stream().subscribe(onimage)) + self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) + + self._camera_info_thread = Thread( + target=self._publish_camera_info, + daemon=True, + ) + self._camera_info_thread.start() + + @rpc + def stop(self) -> None: + if self._camera_info_thread and self._camera_info_thread.is_alive(): + self._camera_info_thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) + super().stop() + + @simple_mcache + def _lidar_stream(self) -> Observable[PointCloud2]: + store: TimedSensorReplay[PointCloud2] = TimedSensorReplay(f"{self.dataset}/lidar") + return store.stream(**self.replay_config) + + @simple_mcache + def _odom_stream(self) -> Observable[PoseStamped]: + store: TimedSensorReplay[PoseStamped] = TimedSensorReplay(f"{self.dataset}/odom") + return store.stream(**self.replay_config) + + @simple_mcache + def _video_stream(self) -> Observable[Image]: + store: TimedSensorReplay[Image] = TimedSensorReplay(f"{self.dataset}/color_image") + return store.stream(**self.replay_config) + + @rpc + def move(self, twist: Twist, duration: float = 0.0) -> bool: + return True + + @rpc + def standup(self) -> bool: + return True + + @rpc + def liedown(self) -> bool: + return True + + @rpc + def balance_stand(self) -> bool: + return True + + @rpc + def enable_rage_mode(self) -> bool: + return True + + def set_obstacle_avoidance(self, enabled: bool = True) -> None: + pass + + @rpc + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: + return {"status": "ok", "message": "Fake publish"} + + def _publish_tf(self, msg: PoseStamped) -> None: + self.tf.publish(*odom_to_tf(msg)) + self.odom.publish(msg) + + def _publish_camera_info(self) -> None: + while True: + self.camera_info.publish(self.camera_info_static) + time.sleep(1.0) + + @skill + def observe(self) -> Image | None: + """Returns the latest video frame from the robot camera.""" + return self._latest_video_frame diff --git a/dimos/robot/unitree/go2/connection_webrtc.py b/dimos/robot/unitree/go2/connection_webrtc.py new file mode 100644 index 0000000000..ba5a336c56 --- /dev/null +++ b/dimos/robot/unitree/go2/connection_webrtc.py @@ -0,0 +1,322 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Real-hardware Go2 connection over Unitree's WebRTC stack. + +Composes `UnitreeWebRtcSession` for the asyncio loop / handshake / move / +publish_request. This file adds Go2-specific stream wiring (lidar, odom, +video, camera_info), sport-mode RPCs, and TF publishing. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from threading import Event, Thread +import time +from typing import Any, TypeAlias + +import numpy as np +from numpy.typing import NDArray +from reactivex import operators as ops +from reactivex.disposable import Disposable +from reactivex.observable import Observable +from reactivex.subject import Subject +import rerun.blueprint as rrb +from unitree_webrtc_connect.constants import RTC_TOPIC, SPORT_CMD + +from dimos.agents.annotation import skill +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.robot.connection_registry import connection +from dimos.robot.tf_utils import odom_to_tf +from dimos.robot.unitree.go2.camera import _camera_info_static +from dimos.robot.unitree.go2.config import ConnectionConfig, Go2Mode +from dimos.robot.unitree.type.lidar import ( + pointcloud2_from_webrtc_lidar, + repair_stale_ts, +) +from dimos.robot.unitree.type.odometry import Odometry as OdometryConverter +from dimos.robot.unitree.webrtc_session import UnitreeWebRtcSession +from dimos.spec.perception import Camera, Pointcloud +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.logging_config import setup_logger +from dimos.utils.reactive import backpressure + +VideoMessage: TypeAlias = NDArray[np.uint8] + +logger = setup_logger() + + +@dataclass +class SerializableVideoFrame: + """Pickleable wrapper for av.VideoFrame with all metadata.""" + + data: np.ndarray # type: ignore[type-arg] + pts: int | None = None + time: float | None = None + dts: int | None = None + width: int | None = None + height: int | None = None + format: str | None = None + + @classmethod + def from_av_frame(cls, frame: Any) -> SerializableVideoFrame: + return cls( + data=frame.to_ndarray(format="rgb24"), + pts=frame.pts, + time=frame.time, + dts=frame.dts, + width=frame.width, + height=frame.height, + format=frame.format.name if hasattr(frame, "format") and frame.format else None, + ) + + def to_ndarray(self, format: str | None = None) -> np.ndarray: # type: ignore[type-arg] + return self.data + + +_SPORT_API_ID_RAGEMODE: int = 2059 + + +@connection(robot="go2", backend="webrtc") +class Go2WebRtcConnection(Module, Camera, Pointcloud): + """Real-hardware Go2 connection over Unitree's WebRTC stack.""" + + config: ConnectionConfig + cmd_vel: In[Twist] + pointcloud: Out[PointCloud2] + odom: Out[PoseStamped] + lidar: Out[PointCloud2] + color_image: Out[Image] + camera_info: Out[CameraInfo] + + camera_info_static: CameraInfo = _camera_info_static() + _camera_info_thread: Thread | None = None + _latest_video_frame: Image | None = None + + @classmethod + def rerun_views(cls): # type: ignore[no-untyped-def] + return [ + rrb.Spatial2DView( + name="Camera", + origin="world/robot/camera/rgb", + ), + ] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.session = UnitreeWebRtcSession(self.config.ip) + + @rpc + def start(self) -> None: + super().start() + self.session.start() + + def onimage(image: Image) -> None: + self.color_image.publish(image) + self._latest_video_frame = image + + self.register_disposable(self._lidar_stream().subscribe(self.lidar.publish)) + self.register_disposable(self._odom_stream().subscribe(self._publish_tf)) + self.register_disposable(self._video_stream().subscribe(onimage)) + self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) + + self._camera_info_thread = Thread( + target=self._publish_camera_info, + daemon=True, + ) + self._camera_info_thread.start() + + self.standup() + time.sleep(3) + self.balance_stand() + + if self.config.mode == Go2Mode.RAGE: + self.enable_rage_mode() + + self.set_obstacle_avoidance(self.config.g.obstacle_avoidance) + + @rpc + def stop(self) -> None: + self.liedown() + self.session.stop() + + from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT + + if self._camera_info_thread and self._camera_info_thread.is_alive(): + self._camera_info_thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) + + super().stop() + + @simple_mcache + def _raw_lidar_stream(self) -> Observable: # type: ignore[type-arg] + return backpressure(self.session.sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) + + @simple_mcache + def _raw_odom_stream(self) -> Observable: # type: ignore[type-arg] + return backpressure(self.session.sub_stream(RTC_TOPIC["ROBOTODOM"])) + + @simple_mcache + def _lidar_stream(self) -> Observable[PointCloud2]: + return backpressure( + self._raw_lidar_stream().pipe( + ops.map(pointcloud2_from_webrtc_lidar), + repair_stale_ts(), + ) + ) + + @simple_mcache + def _odom_stream(self) -> Observable[PoseStamped]: + return backpressure(self._raw_odom_stream().pipe(ops.map(OdometryConverter.from_msg))) + + @simple_mcache + def _video_stream(self) -> Observable[Image]: + return backpressure( + self._raw_video_stream().pipe( + ops.filter(lambda frame: frame is not None), + ops.map( + lambda frame: Image.from_numpy( + frame.to_ndarray(format="rgb24"), # type: ignore[attr-defined] + format=ImageFormat.RGB, + frame_id="camera_optical", + ) + ), + ) + ) + + @simple_mcache + def _raw_video_stream(self) -> Observable[SerializableVideoFrame]: + subject: Subject[SerializableVideoFrame] = Subject() + stop_event = Event() + + from aiortc import MediaStreamTrack + + conn = self.session.conn + loop = self.session.loop + + async def accept_track(track: MediaStreamTrack) -> None: + while True: + if stop_event.is_set(): + return + frame = await track.recv() + serializable_frame = SerializableVideoFrame.from_av_frame(frame) + subject.on_next(serializable_frame) + + conn.video.add_track_callback(accept_track) + + def switch_video_channel() -> None: + conn.video.switchVideoChannel(True) + + loop.call_soon_threadsafe(switch_video_channel) + + def stop() -> None: + stop_event.set() + conn.video.track_callbacks.remove(accept_track) + + def switch_video_channel_off() -> None: + conn.video.switchVideoChannel(False) + + loop.call_soon_threadsafe(switch_video_channel_off) + + return subject.pipe(ops.finally_action(stop)) + + @rpc + def move(self, twist: Twist, duration: float = 0.0) -> bool: + """Send movement command to robot.""" + return self.session.move(twist, duration) + + @rpc + def standup(self) -> bool: + """Make the robot stand up.""" + return bool( + self.session.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) + ) + + @rpc + def liedown(self) -> bool: + """Make the robot lie down.""" + return bool( + self.session.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) + ) + + @rpc + def balance_stand(self) -> bool: + """Enter BalanceStand: neutral state for switching locomotion modes.""" + return bool( + self.session.publish_request( + RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]} + ) + ) + + @rpc + def enable_rage_mode(self) -> bool: + """Enable Rage Mode (~2.5 m/s forward velocity envelope). + + Ensures BalanceStand precondition regardless of current FSM state. + """ + # Force BalanceStand first. + self.session.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) + time.sleep(0.3) + + rage_ok = bool( + self.session.publish_request( + RTC_TOPIC["SPORT_MOD"], + {"api_id": _SPORT_API_ID_RAGEMODE, "parameter": {"data": True}}, + ) + ) + time.sleep(2.0) + + joystick_ok = bool( + self.session.publish_request( + RTC_TOPIC["SPORT_MOD"], + {"api_id": SPORT_CMD["SwitchJoystick"], "parameter": {"data": True}}, + ) + ) + logger.info("Rage Mode enabled") + return rage_ok and joystick_ok + + def set_obstacle_avoidance(self, enabled: bool = True) -> None: + self.session.publish_request( + RTC_TOPIC["OBSTACLES_AVOID"], + {"api_id": 1001, "parameter": {"enable": int(enabled)}}, + ) + + @rpc + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: + """Publish a request to the underlying connection.""" + return self.session.publish_request(topic, data) # type: ignore[no-any-return] + + def _publish_tf(self, msg: PoseStamped) -> None: + self.tf.publish(*odom_to_tf(msg)) + self.odom.publish(msg) + + def _publish_camera_info(self) -> None: + while True: + self.camera_info.publish(self.camera_info_static) + time.sleep(1.0) + + @skill + def observe(self) -> Image | None: + """Returns the latest video frame from the robot camera. Use this skill for any visual world queries. + + This skill provides the current camera view for perception tasks. + Returns None if no frame has been captured yet. + """ + return self._latest_video_frame diff --git a/dimos/robot/unitree/go2/fleet_connection.py b/dimos/robot/unitree/go2/fleet_connection.py index f2a0216ab7..a938a5d51b 100644 --- a/dimos/robot/unitree/go2/fleet_connection.py +++ b/dimos/robot/unitree/go2/fleet_connection.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Go2 Fleet Connection - manage multiple Go2 robots as a fleet""" +"""Go2 Fleet Connection — manage multiple Go2 robots as a fleet. + +The primary robot uses the full Go2WebRtcConnection (sensors + RPCs). +Additional robots use a minimal command-only client (no sensor streams), +composing `UnitreeWebRtcSession` for transport. +""" from __future__ import annotations @@ -21,14 +26,12 @@ from typing import TYPE_CHECKING, Any from pydantic import Field, model_validator +from unitree_webrtc_connect.constants import RTC_TOPIC, SPORT_CMD from dimos.core.core import rpc -from dimos.robot.unitree.go2.connection import ( - ConnectionConfig, - GO2Connection, - Go2ConnectionProtocol, - make_connection, -) +from dimos.robot.unitree.go2.config import ConnectionConfig +from dimos.robot.unitree.go2.connection_webrtc import Go2WebRtcConnection +from dimos.robot.unitree.webrtc_session import UnitreeWebRtcSession from dimos.utils.logging_config import setup_logger if sys.version_info >= (3, 11): @@ -54,10 +57,60 @@ def set_ip_after_validation(self) -> Self: return self -class Go2FleetConnection(GO2Connection): - """Inherits all single-robot behaviour from GO2Connection for the primary - (first) robot. Additional robots only receive broadcast commands - (move, standup, liedown, publish_request). +class _FleetMemberClient: + """Command-only WebRTC client for extra fleet robots. + + Wraps a `UnitreeWebRtcSession` and adds the Go2 sport-mode commands + (standup/liedown/balance_stand/set_obstacle_avoidance). No sensor + streams — fleet does not subscribe to extras. + """ + + def __init__(self, ip: str) -> None: + self.session = UnitreeWebRtcSession(ip) + + def start(self) -> None: + self.session.start() + + def stop(self) -> None: + self.session.stop() + + def move(self, twist: Twist, duration: float = 0.0) -> bool: + return self.session.move(twist, duration) + + def standup(self) -> bool: + return bool( + self.session.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) + ) + + def liedown(self) -> bool: + return bool( + self.session.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) + ) + + def balance_stand(self) -> bool: + return bool( + self.session.publish_request( + RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]} + ) + ) + + def set_obstacle_avoidance(self, enabled: bool = True) -> None: + self.session.publish_request( + RTC_TOPIC["OBSTACLES_AVOID"], + {"api_id": 1001, "parameter": {"enable": int(enabled)}}, + ) + + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: + return self.session.publish_request(topic, data) # type: ignore[no-any-return] + + +class Go2FleetConnection(Go2WebRtcConnection): + """Inherits all single-robot behaviour from Go2WebRtcConnection for the + primary (first) robot. Additional robots only receive broadcast commands + (move, standup, liedown, balance_stand, set_obstacle_avoidance, + publish_request) via _FleetMemberClient. + + Fleets are real-hardware only — there's no sim/replay equivalent. """ config: FleetConnectionConfig @@ -65,47 +118,43 @@ class Go2FleetConnection(GO2Connection): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self._extra_ips = self.config.ips[1:] - self._extra_connections: list[Go2ConnectionProtocol] = [] + self._extra_connections: list[_FleetMemberClient] = [] @rpc def start(self) -> None: self._extra_connections.clear() for ip in self._extra_ips: - conn = make_connection(ip, self.config.g) - conn.start() - self._extra_connections.append(conn) + client = _FleetMemberClient(ip) + client.start() + self._extra_connections.append(client) - # Parent starts primary robot, subscribes sensors, calls standup() on all + # Parent starts primary robot, subscribes sensors, calls standup() on it. super().start() - for conn in self._extra_connections: - conn.balance_stand() - conn.set_obstacle_avoidance(self.config.g.obstacle_avoidance) + for client in self._extra_connections: + client.balance_stand() + client.set_obstacle_avoidance(self.config.g.obstacle_avoidance) @rpc def stop(self) -> None: - # one robot's error should not prevent others from stopping - for conn in self._extra_connections: + # One robot's error must not prevent others from stopping. + for client in self._extra_connections: try: - conn.liedown() + client.liedown() except Exception as e: logger.error(f"Error lying down fleet Go2: {e}") try: - conn.stop() + client.stop() except Exception as e: logger.error(f"Error stopping fleet Go2: {e}") self._extra_connections.clear() super().stop() - @property - def _all_connections(self) -> list[Go2ConnectionProtocol]: - return [self.connection, *self._extra_connections] - @rpc def move(self, twist: Twist, duration: float = 0.0) -> bool: - results: list[bool] = [] - for conn in self._all_connections: + results: list[bool] = [super().move(twist, duration)] + for client in self._extra_connections: try: - results.append(conn.move(twist, duration)) + results.append(client.move(twist, duration)) except Exception as e: logger.error(f"Fleet move failed: {e}") results.append(False) @@ -113,10 +162,10 @@ def move(self, twist: Twist, duration: float = 0.0) -> bool: @rpc def standup(self) -> bool: - results: list[bool] = [] - for conn in self._all_connections: + results: list[bool] = [super().standup()] + for client in self._extra_connections: try: - results.append(conn.standup()) + results.append(client.standup()) except Exception as e: logger.error(f"Fleet standup failed: {e}") results.append(False) @@ -124,10 +173,10 @@ def standup(self) -> bool: @rpc def liedown(self) -> bool: - results: list[bool] = [] - for conn in self._all_connections: + results: list[bool] = [super().liedown()] + for client in self._extra_connections: try: - results.append(conn.liedown()) + results.append(client.liedown()) except Exception as e: logger.error(f"Fleet liedown failed: {e}") results.append(False) @@ -136,9 +185,9 @@ def liedown(self) -> bool: @rpc def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: """Publish a request to all robots, return primary's response.""" - for conn in self._extra_connections: + for client in self._extra_connections: try: - conn.publish_request(topic, data) + client.publish_request(topic, data) except Exception as e: logger.error(f"Fleet publish_request failed: {e}") - return self.connection.publish_request(topic, data) + return super().publish_request(topic, data) diff --git a/dimos/robot/unitree/mujoco_camera_constants.py b/dimos/robot/unitree/mujoco_camera_constants.py new file mode 100644 index 0000000000..0edff84d2c --- /dev/null +++ b/dimos/robot/unitree/mujoco_camera_constants.py @@ -0,0 +1,51 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MuJoCo sim camera intrinsics constant, shared by sim connection modules +(Go2MujocoConnection, G1MujocoConnection) and by external readers that need +the value without pulling in the full mujoco transport class. +""" + +from __future__ import annotations + +import math + +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.simulation.mujoco.constants import ( + VIDEO_CAMERA_FOV, + VIDEO_HEIGHT, + VIDEO_WIDTH, +) + + +def _compute_mujoco_camera_info() -> CameraInfo: + """Pinhole model: f = height / (2 * tan(fovy / 2)).""" + fovy = math.radians(VIDEO_CAMERA_FOV) + f = VIDEO_HEIGHT / (2 * math.tan(fovy / 2)) + cx = VIDEO_WIDTH / 2.0 + cy = VIDEO_HEIGHT / 2.0 + + return CameraInfo( + frame_id="camera_optical", + height=VIDEO_HEIGHT, + width=VIDEO_WIDTH, + distortion_model="plumb_bob", + D=[0.0, 0.0, 0.0, 0.0, 0.0], + K=[f, 0.0, cx, 0.0, f, cy, 0.0, 0.0, 1.0], + R=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + P=[f, 0.0, cx, 0.0, 0.0, f, cy, 0.0, 0.0, 0.0, 1.0, 0.0], + ) + + +MUJOCO_CAMERA_INFO_STATIC: CameraInfo = _compute_mujoco_camera_info() diff --git a/dimos/robot/unitree/mujoco_connection.py b/dimos/robot/unitree/mujoco_connection.py index 1e3b77fe1a..ff28ef37a0 100644 --- a/dimos/robot/unitree/mujoco_connection.py +++ b/dimos/robot/unitree/mujoco_connection.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - # Copyright 2025-2026 Dimensional Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,128 +12,165 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Shared base for MuJoCo-simulated Unitree connections. + +Holds the MuJoCo subprocess + shared-memory IPC, stream wiring, and command +dispatch that every sim backend needs. Robot-specific subclasses customize +the config type, extra ports, start/stop sequencing, camera mounting offset, +and any extra TF transforms. +""" + +from __future__ import annotations import atexit import base64 from collections.abc import Callable import functools import json -import math import os from pathlib import Path import pickle import subprocess import sys import sysconfig -import threading +from threading import Event, Thread, Timer import time from typing import Any, TypeVar import weakref import numpy as np -from numpy.typing import NDArray from reactivex import Observable from reactivex.abc import ObserverBase, SchedulerBase from reactivex.disposable import Disposable from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT -from dimos.core.global_config import GlobalConfig +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.robot.unitree.type.odometry import Odometry -from dimos.simulation.mujoco.constants import ( - LAUNCHER_PATH, - LIDAR_FPS, - VIDEO_CAMERA_FOV, - VIDEO_FPS, - VIDEO_HEIGHT, - VIDEO_WIDTH, -) -from dimos.simulation.mujoco.shared_memory import ShmWriter +from dimos.robot.tf_utils import odom_to_tf +from dimos.robot.unitree.mujoco_camera_constants import MUJOCO_CAMERA_INFO_STATIC from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger -ODOM_FREQUENCY = 50 - logger = setup_logger() T = TypeVar("T") +_ODOM_FREQUENCY = 50 + -class MujocoConnection: - """MuJoCo simulator connection that runs in a separate subprocess.""" +class MujocoConnectionBase(Module): + """Shared MuJoCo subprocess + stream plumbing for Unitree sim connections.""" + + cmd_vel: In[Twist] + lidar: Out[PointCloud2] + odom: Out[PoseStamped] + color_image: Out[Image] + camera_info: Out[CameraInfo] + + camera_info_static: CameraInfo = MUJOCO_CAMERA_INFO_STATIC + _camera_info_thread: Thread | None = None + _latest_video_frame: Image | None = None + + # Translation from base_link to camera_link. Subclasses override to match + # their robot's camera mounting position. + _camera_link_offset: Vector3 = Vector3(0.0, 0.0, 0.0) + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) - def __init__(self, global_config: GlobalConfig) -> None: try: import mujoco # noqa: F401 except ImportError: raise ImportError("'mujoco' is not installed. Use `pip install -e .[sim]`") - # Pre-download the mujoco_sim data. get_data("mujoco_sim") - # Trigger the download of the mujoco_menagerie package. This is so it - # doesn't trigger in the mujoco process where it can time out. + # Trigger menagerie download outside the subprocess to avoid timeout there. from mujoco_playground._src import mjx_env mjx_env.ensure_menagerie_exists() - self.global_config = global_config self.process: subprocess.Popen[bytes] | None = None - self.shm_data: ShmWriter | None = None + self.shm_data: Any = None # ShmWriter, lazily imported self._last_video_seq = 0 self._last_odom_seq = 0 self._last_lidar_seq = 0 - self._stop_timer: threading.Timer | None = None + self._cmd_stop_timer: Timer | None = None - self._stream_threads: list[threading.Thread] = [] - self._stop_events: list[threading.Event] = [] + self._stream_threads: list[Thread] = [] + self._stop_events: list[Event] = [] self._is_cleaned_up = False + self._stop_event = Event() + + @rpc + def start(self) -> None: + super().start() + self._start_subprocess() + + def onimage(image: Image) -> None: + self.color_image.publish(image) + self._latest_video_frame = image + + self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) + self.register_disposable(self._odom_stream().subscribe(self._publish_tf)) + self.register_disposable(self._lidar_stream().subscribe(self.lidar.publish)) + self.register_disposable(self._video_stream().subscribe(onimage)) - @staticmethod - def _compute_camera_info() -> CameraInfo: - """Compute camera intrinsics from MuJoCo camera parameters. - - Uses pinhole camera model: f = height / (2 * tan(fovy / 2)) - """ - fovy = math.radians(VIDEO_CAMERA_FOV) - f = VIDEO_HEIGHT / (2 * math.tan(fovy / 2)) - cx = VIDEO_WIDTH / 2.0 - cy = VIDEO_HEIGHT / 2.0 - - return CameraInfo( - frame_id="camera_optical", - height=VIDEO_HEIGHT, - width=VIDEO_WIDTH, - distortion_model="plumb_bob", - D=[0.0, 0.0, 0.0, 0.0, 0.0], - K=[f, 0.0, cx, 0.0, f, cy, 0.0, 0.0, 1.0], - R=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], - P=[f, 0.0, cx, 0.0, 0.0, f, cy, 0.0, 0.0, 0.0, 1.0, 0.0], + self._camera_info_thread = Thread( + target=self._publish_camera_info_loop, + daemon=True, ) + self._camera_info_thread.start() - camera_info_static: CameraInfo = _compute_camera_info() + self._on_start() + + @rpc + def stop(self) -> None: + self._on_stop() + self._stop_event.set() + self._teardown_subprocess() + + if self._camera_info_thread and self._camera_info_thread.is_alive(): + self._camera_info_thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) + + super().stop() + + def _on_start(self) -> None: + """Subclass hook: extra setup after streams are wired (e.g., standup).""" + + def _on_stop(self) -> None: + """Subclass hook: cleanup before subprocess teardown (e.g., liedown).""" + + def _extra_transforms(self, msg: PoseStamped) -> list[Transform]: + """Subclass hook: extra TF transforms to publish alongside the standard ones.""" + return [] + + def _start_subprocess(self) -> None: + from dimos.simulation.mujoco.constants import LAUNCHER_PATH + from dimos.simulation.mujoco.shared_memory import ShmWriter - def start(self) -> None: self.shm_data = ShmWriter() - config_pickle = base64.b64encode(pickle.dumps(self.global_config)).decode("ascii") + config_pickle = base64.b64encode(pickle.dumps(self.config.g)).decode("ascii") shm_names_json = json.dumps(self.shm_data.shm.to_names()) - # Launch the subprocess try: - # mjpython must be used on macOS (because of launch_passive inside mujoco_process.py). + # mjpython must be used on macOS (because of launch_passive inside the subprocess). # It needs libpython on the dylib search path; uv-installed Pythons # use @rpath which doesn't always resolve inside venvs, so we # point DYLD_LIBRARY_PATH at the real libpython directory. executable = sys.executable if sys.platform != "darwin" else "mjpython" env = os.environ.copy() if sys.platform == "darwin": - # on some systems mujoco looks in the wrong place for shared libraries. So we force it look in the right place libdir = Path(sysconfig.get_config_var("LIBDIR") or "") if libdir.is_dir(): existing = env.get("DYLD_LIBRARY_PATH", "") @@ -146,75 +181,64 @@ def start(self) -> None: stderr=subprocess.PIPE, env=env, ) - except Exception as e: self.shm_data.cleanup() raise RuntimeError(f"Failed to start MuJoCo subprocess: {e}") from e - # Wait for process to be ready ready_timeout = 300.0 start_time = time.time() assert self.process is not None while time.time() - start_time < ready_timeout: if self.process.poll() is not None: exit_code = self.process.returncode - self.stop() + self._teardown_subprocess() raise RuntimeError(f"MuJoCo process failed to start (exit code {exit_code})") if self.shm_data.is_ready(): logger.info("MuJoCo process started successfully") - # Register atexit handler to ensure subprocess is cleaned up - # Use weakref to avoid preventing garbage collection weak_self = weakref.ref(self) def cleanup_on_exit( - weak_self: "weakref.ReferenceType[MujocoConnection]" = weak_self, + weak_self: weakref.ReferenceType[MujocoConnectionBase] = weak_self, ) -> None: instance = weak_self() if instance is not None: - instance.stop() + instance._teardown_subprocess() atexit.register(cleanup_on_exit) return time.sleep(0.1) - # Timeout - self.stop() + self._teardown_subprocess() raise RuntimeError("MuJoCo process failed to start (timeout)") - def stop(self) -> None: + def _teardown_subprocess(self) -> None: if self._is_cleaned_up: return self._is_cleaned_up = True - # clean up open file descriptors if self.process: if self.process.stderr: self.process.stderr.close() if self.process.stdout: self.process.stdout.close() - # Cancel any pending timers - if self._stop_timer: - self._stop_timer.cancel() - self._stop_timer = None + if self._cmd_stop_timer: + self._cmd_stop_timer.cancel() + self._cmd_stop_timer = None - # Stop all stream threads for stop_event in self._stop_events: stop_event.set() - # Wait for threads to finish for thread in self._stream_threads: if thread.is_alive(): thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) if thread.is_alive(): logger.warning(f"Stream thread {thread.name} did not stop gracefully") - # Signal subprocess to stop if self.shm_data: self.shm_data.signal_stop() - # Wait for process to finish if self.process: try: self.process.terminate() @@ -229,46 +253,29 @@ def stop(self) -> None: self.process = None - # Clean up shared memory if self.shm_data: self.shm_data.cleanup() self.shm_data = None - # Clear references self._stream_threads.clear() self._stop_events.clear() - self.lidar_stream.cache_clear() - self.odom_stream.cache_clear() - self.video_stream.cache_clear() - - def standup(self) -> bool: - return True - - def liedown(self) -> bool: - return True - - def balance_stand(self) -> bool: - return True + self._lidar_stream.cache_clear() + self._odom_stream.cache_clear() + self._video_stream.cache_clear() - def set_obstacle_avoidance(self, enabled: bool = True) -> None: - pass - - def enable_rage_mode(self) -> bool: - return True - - def get_video_frame(self) -> NDArray[Any] | None: + def _get_video_frame(self) -> np.ndarray | None: # type: ignore[type-arg] if self.shm_data is None: return None frame, seq = self.shm_data.read_video() if seq > self._last_video_seq: self._last_video_seq = seq - return frame + return frame # type: ignore[no-any-return] return None - def get_odom_message(self) -> Odometry | None: + def _get_odom_message(self) -> PoseStamped | None: if self.shm_data is None: return None @@ -277,10 +284,12 @@ def get_odom_message(self) -> Odometry | None: self._last_odom_seq = seq pos, quat_wxyz, timestamp = odom_data - # Convert quaternion from (w,x,y,z) to (x,y,z,w) for ROS/Dimos + # Convert quaternion from (w,x,y,z) to (x,y,z,w) for ROS/Dimos. + from dimos.robot.unitree.type.odometry import Odometry as OdometryMsg + orientation = Quaternion(quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]) - return Odometry( + return OdometryMsg( position=Vector3(pos[0], pos[1], pos[2]), orientation=orientation, ts=timestamp, @@ -289,14 +298,14 @@ def get_odom_message(self) -> Odometry | None: return None - def get_lidar_message(self) -> PointCloud2 | None: + def _get_lidar_message(self) -> PointCloud2 | None: if self.shm_data is None: return None lidar_msg, seq = self.shm_data.read_lidar() if seq > self._last_lidar_seq and lidar_msg is not None: self._last_lidar_seq = seq - return lidar_msg + return lidar_msg # type: ignore[no-any-return] return None @@ -311,7 +320,7 @@ def on_subscribe(observer: ObserverBase[T], _scheduler: SchedulerBase | None) -> observer.on_completed() return Disposable(lambda: None) - stop_event = threading.Event() + stop_event = Event() self._stop_events.append(stop_event) def run() -> None: @@ -326,7 +335,7 @@ def run() -> None: finally: observer.on_completed() - thread = threading.Thread(target=run, daemon=True) + thread = Thread(target=run, daemon=True) self._stream_threads.append(thread) thread.start() @@ -338,23 +347,44 @@ def dispose() -> None: return Observable(on_subscribe) @functools.cache - def lidar_stream(self) -> Observable[PointCloud2]: - return self._create_stream(self.get_lidar_message, LIDAR_FPS, "Lidar") + def _lidar_stream(self) -> Observable[PointCloud2]: + from dimos.simulation.mujoco.constants import LIDAR_FPS + + return self._create_stream(self._get_lidar_message, LIDAR_FPS, "Lidar") @functools.cache - def odom_stream(self) -> Observable[Odometry]: - return self._create_stream(self.get_odom_message, ODOM_FREQUENCY, "Odom") + def _odom_stream(self) -> Observable[PoseStamped]: + return self._create_stream(self._get_odom_message, _ODOM_FREQUENCY, "Odom") @functools.cache - def video_stream(self) -> Observable[Image]: + def _video_stream(self) -> Observable[Image]: + from dimos.simulation.mujoco.constants import VIDEO_FPS + def get_video_as_image() -> Image | None: - frame = self.get_video_frame() + frame = self._get_video_frame() # MuJoCo renderer returns RGB uint8 frames; Image.from_numpy defaults to BGR. return Image.from_numpy(frame, format=ImageFormat.RGB) if frame is not None else None return self._create_stream(get_video_as_image, VIDEO_FPS, "Video") + def _publish_camera_info_loop(self) -> None: + while not self._stop_event.is_set(): + self.camera_info.publish(self.camera_info_static) + self._stop_event.wait(1.0) + + def _publish_tf(self, msg: PoseStamped) -> None: + self.odom.publish(msg) + self.tf.publish( + *odom_to_tf( + msg, + camera_link_offset=self._camera_link_offset, + extras=self._extra_transforms(msg), + ) + ) + + @rpc def move(self, twist: Twist, duration: float = 0.0) -> bool: + """Send movement command to the sim via shared memory.""" if self._is_cleaned_up or self.shm_data is None: return True @@ -363,21 +393,22 @@ def move(self, twist: Twist, duration: float = 0.0) -> bool: self.shm_data.write_command(linear, angular) if duration > 0: - if self._stop_timer: - self._stop_timer.cancel() + if self._cmd_stop_timer: + self._cmd_stop_timer.cancel() def stop_movement() -> None: if self.shm_data: self.shm_data.write_command( np.zeros(3, dtype=np.float32), np.zeros(3, dtype=np.float32) ) - self._stop_timer = None + self._cmd_stop_timer = None - self._stop_timer = threading.Timer(duration, stop_movement) - self._stop_timer.daemon = True - self._stop_timer.start() + self._cmd_stop_timer = Timer(duration, stop_movement) + self._cmd_stop_timer.daemon = True + self._cmd_stop_timer.start() return True + @rpc def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: - print(f"publishing request, topic={topic}, data={data}") + logger.info(f"sim publish_request stub: topic={topic} data={data}") return {} diff --git a/dimos/robot/unitree/type/map.py b/dimos/robot/unitree/type/map.py index f98eba2c3c..37a24dc928 100644 --- a/dimos/robot/unitree/type/map.py +++ b/dimos/robot/unitree/type/map.py @@ -20,17 +20,14 @@ from reactivex import interval from reactivex.disposable import Disposable -from dimos.core.coordination.module_coordinator import ModuleCoordinator from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.core.transport import LCMTransport from dimos.mapping.pointclouds.accumulators.general import GeneralPointCloudAccumulator from dimos.mapping.pointclouds.accumulators.protocol import PointCloudAccumulator from dimos.mapping.pointclouds.occupancy import general_occupancy from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.robot.unitree.go2.connection import Go2ConnectionProtocol class MapConfig(ModuleConfig): @@ -109,12 +106,3 @@ def _publish(self, _: Any) -> None: occupancygrid = self._preloaded_occupancy self.global_costmap.publish(occupancygrid) - - -def deploy(dimos: ModuleCoordinator, connection: Go2ConnectionProtocol): # type: ignore[no-untyped-def] - mapper = dimos.deploy(Map, global_publish_interval=1.0) - mapper.global_map.transport = LCMTransport("/global_map", PointCloud2) - mapper.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid) - mapper.lidar.connect(connection.pointcloud) # type: ignore[attr-defined] - mapper.start() - return mapper diff --git a/dimos/robot/unitree/webrtc_session.py b/dimos/robot/unitree/webrtc_session.py new file mode 100644 index 0000000000..25b3439561 --- /dev/null +++ b/dimos/robot/unitree/webrtc_session.py @@ -0,0 +1,186 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared WebRTC session for Unitree robots. + +Owns the asyncio event loop, background thread, and LegionConnection +lifecycle for one robot. Used by composition from the per-robot Module +classes (Go2WebRtcConnection, G1WebRtcConnection) and from the fleet-member +helper. +""" + +from __future__ import annotations + +import asyncio +from threading import Event, Thread, Timer +import time +from typing import Any + +from unitree_webrtc_connect.constants import RTC_TOPIC +from unitree_webrtc_connect.webrtc_driver import ( + UnitreeWebRTCConnection as LegionConnection, + WebRTCConnectionMethod, +) + +from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.utils.logging_config import setup_logger +from dimos.utils.reactive import callback_to_observable + +logger = setup_logger() + + +class UnitreeWebRtcSession: + """Asyncio loop + LegionConnection lifecycle for one Unitree robot. + + Construction is cheap (no network). `start()` runs the WebRTC handshake + and blocks until ready. `stop()` disconnects cleanly and joins the + background thread. + + The underlying `loop` and `conn` are exposed for callers that need + direct asyncio/WebRTC access (e.g., video track subscription). + """ + + def __init__(self, ip: str, *, mode_name: str = "ai", cmd_vel_timeout: float = 0.2) -> None: + assert ip, "IP address must be provided" + self.ip = ip + self.mode_name = mode_name + self.cmd_vel_timeout = cmd_vel_timeout + + self.loop = asyncio.new_event_loop() + self.conn = LegionConnection(WebRTCConnectionMethod.LocalSTA, ip=ip) + + self._task: asyncio.Task[None] | None = None + self._thread: Thread | None = None + self._connection_ready = Event() + self._stop_timer: Timer | None = None + + def start(self) -> None: + async def async_connect() -> None: + await self.conn.connect() + await self.conn.datachannel.disableTrafficSaving(True) + self.conn.datachannel.set_decoder(decoder_type="native") + await self.conn.datachannel.pub_sub.publish_request_new( + RTC_TOPIC["MOTION_SWITCHER"], + {"api_id": 1002, "parameter": {"name": self.mode_name}}, + ) + self._connection_ready.set() + while True: + await asyncio.sleep(1) + + def start_background_loop() -> None: + asyncio.set_event_loop(self.loop) + self._task = self.loop.create_task(async_connect()) + self.loop.run_forever() + + self._thread = Thread(target=start_background_loop, daemon=True) + self._thread.start() + self._connection_ready.wait() + + def stop(self) -> None: + if self._stop_timer: + self._stop_timer.cancel() + self._stop_timer = None + + if self._task: + self._task.cancel() + + async def async_disconnect() -> None: + try: + self.conn.datachannel.pub_sub.publish_without_callback( + RTC_TOPIC["WIRELESS_CONTROLLER"], + data={"lx": 0, "ly": 0, "rx": 0, "ry": 0}, + ) + await self.conn.disconnect() + except Exception: + pass + + if self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + self.loop.call_soon_threadsafe(self.loop.stop) + + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) + + def _stop_movement(self) -> None: + if self._stop_timer: + self._stop_timer.cancel() + self._stop_timer = None + + def move(self, twist: Twist, duration: float = 0.0) -> bool: + """Send a Twist as a WIRELESS_CONTROLLER command, auto-stopping after cmd_vel_timeout.""" + x, y, yaw = twist.linear.x, twist.linear.y, twist.angular.z + + # WebRTC coordinate mapping: + # x - positive right, negative left + # y - positive forward, negative backwards + # yaw - positive rotate right, negative rotate left + async def async_move() -> None: + self.conn.datachannel.pub_sub.publish_without_callback( + RTC_TOPIC["WIRELESS_CONTROLLER"], + data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, + ) + + async def async_move_duration() -> None: + start_time = time.time() + while time.time() - start_time < duration: + await async_move() + await asyncio.sleep(0.01) + + if self._stop_timer: + self._stop_timer.cancel() + + self._stop_timer = Timer(self.cmd_vel_timeout, self._stop_movement) + self._stop_timer.daemon = True + self._stop_timer.start() + + try: + if duration > 0: + future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) + future.result() + self._stop_movement() + else: + future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) + future.result() + return True + except Exception as e: + logger.error(f"Failed to send movement command: {e}") + return False + + def publish_request(self, topic: str, data: dict[Any, Any]) -> Any: + """Synchronous wrapper around publish_request_new running on the session loop.""" + future = asyncio.run_coroutine_threadsafe( + self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop + ) + return future.result() + + def sub_stream(self, topic_name: str): # type: ignore[no-untyped-def] + """Convert a Unitree pub/sub topic into an observable stream.""" + + def subscribe_in_thread(cb) -> None: # type: ignore[no-untyped-def] + def run_subscription() -> None: + self.conn.datachannel.pub_sub.subscribe(topic_name, cb) + + self.loop.call_soon_threadsafe(run_subscription) + + def unsubscribe_in_thread(cb) -> None: # type: ignore[no-untyped-def] + def run_unsubscription() -> None: + self.conn.datachannel.pub_sub.unsubscribe(topic_name) + + self.loop.call_soon_threadsafe(run_unsubscription) + + return callback_to_observable( + start=subscribe_in_thread, + stop=unsubscribe_in_thread, + ) diff --git a/dimos/utils/testing/test_moment.py b/dimos/utils/testing/test_moment.py index dcca3d7d01..4dfbd12421 100644 --- a/dimos/utils/testing/test_moment.py +++ b/dimos/utils/testing/test_moment.py @@ -20,7 +20,8 @@ from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.protocol.tf.tf import TF -from dimos.robot.unitree.go2 import connection +from dimos.robot.tf_utils import odom_to_tf +from dimos.robot.unitree.go2.camera import _camera_info_static from dimos.utils.data import get_data from dimos.utils.testing.moment import Moment, SensorMoment @@ -46,14 +47,14 @@ def transforms(self) -> list[Transform]: # back and forth through time and foxglove doesn't get confused odom = self.odom.value odom.ts = time.time() - return connection.GO2Connection._odom_to_tf(odom) + return odom_to_tf(odom) def publish(self) -> None: t = TF() t.publish(*self.transforms) t.stop() - camera_info = connection._camera_info_static() + camera_info = _camera_info_static() camera_info.ts = time.time() camera_info_transport: LCMTransport[CameraInfo] = LCMTransport("/camera_info", CameraInfo) camera_info_transport.publish(camera_info)