diff --git a/src/reflex/cli.py b/src/reflex/cli.py index dc82c6e..527ef96 100644 --- a/src/reflex/cli.py +++ b/src/reflex/cli.py @@ -2142,7 +2142,19 @@ def ros2_serve( ), state_topic: str = typer.Option( "/joint_states", - help="sensor_msgs/JointState topic — .position field becomes the state vector", + help="State topic. Default reads .position from a JointState (arm " + "convention). For drones, pair --state-msg-type=odom with " + "/mavros/local_position/odom for full 10-DOF state, or --state-" + "msg-type=imu with /mavros/imu/data for 4-DOF orientation-only.", + ), + state_msg_type: str = typer.Option( + "joint_state", + "--state-msg-type", + help="How to interpret messages on --state-topic. One of: " + "'joint_state' (sensor_msgs/JointState .position, arms; default), " + "'imu' (sensor_msgs/Imu .orientation quaternion, drone partial " + "state — 4 DOF), 'odom' (nav_msgs/Odometry pose + linear twist, " + "drone full state — 10 DOF, matches the quadcopter preset).", ), task_topic: str = typer.Option( "/reflex/task", @@ -2208,7 +2220,7 @@ def ros2_serve( console.print(f" export_dir: {export_dir}") console.print(f" node_name: {node_name}") console.print(f" rate_hz: {rate_hz}") - console.print(f" subs: {image_topic}, {state_topic}, {task_topic}") + console.print(f" subs: {image_topic}, {state_topic} ({state_msg_type}), {task_topic}") console.print(f" pub: {action_topic}") if mcp: console.print(f" mcp: {mcp_transport}" + (f" (port {mcp_port})" if mcp_transport == "http" else "")) @@ -2223,10 +2235,16 @@ def ros2_serve( action_topic=action_topic, rate_hz=rate_hz, node_name=node_name, + state_msg_type=state_msg_type, mcp=mcp, mcp_transport=mcp_transport, mcp_port=mcp_port, ) + except ValueError as exc: + # _resolve_state_msg_class / create_ros2_bridge_node raise ValueError + # for unknown --state-msg-type values. + console.print(f"[red]{exc}[/red]") + raise typer.Exit(1) except ImportError as exc: console.print(f"[red]{exc}[/red]") raise typer.Exit(2) diff --git a/src/reflex/runtime/ros2_bridge.py b/src/reflex/runtime/ros2_bridge.py index a8a904b..c0e846a 100644 --- a/src/reflex/runtime/ros2_bridge.py +++ b/src/reflex/runtime/ros2_bridge.py @@ -13,10 +13,21 @@ Default topic layout (override via CLI flags): subs: /camera/image_raw sensor_msgs/msg/Image (rgb8) - /joint_states sensor_msgs/msg/JointState (positions → state vector) + /joint_states state topic (default JointState — see below) /reflex/task std_msgs/msg/String (text instruction) pub: /reflex/actions std_msgs/msg/Float32MultiArray (flat chunk × action_dim) + +State extractors (selected via --state-msg-type): + joint_state sensor_msgs/msg/JointState .position (arms; default) + imu sensor_msgs/msg/Imu .orientation (quat) (drone, partial) + odom nav_msgs/msg/Odometry pose + linear twist (drone, full state) + +Recommended drone topic: /mavros/local_position/odom (nav_msgs/Odometry) — +gives the policy position + orientation + linear velocity in one message, +matching the 10-DOF state shape used by the shipped quadcopter preset. +The IMU path only yields 4 DOF (orientation only) — useful as a fallback +when full odometry isn't available, but expect reduced control quality. """ from __future__ import annotations @@ -30,13 +41,19 @@ def _require_rclpy(): - """Import rclpy + message modules or raise a helpful ImportError.""" + """Import rclpy + always-needed message modules or raise a helpful ImportError. + + Returns the core types every bridge configuration needs. State-specific + message classes (Imu, Odometry) are resolved lazily via + `_resolve_state_msg` so a missing optional package doesn't break arm + deployments that only need JointState. + """ try: import rclpy from rclpy.node import Node - from sensor_msgs.msg import Image, JointState + from sensor_msgs.msg import Image from std_msgs.msg import Float32MultiArray, String - return rclpy, Node, Image, JointState, String, Float32MultiArray + return rclpy, Node, Image, String, Float32MultiArray except ImportError as exc: raise ImportError( "rclpy not available. The reflex ROS2 bridge requires a ROS2 install " @@ -48,6 +65,87 @@ def _require_rclpy(): ) from exc +# --------------------------------------------------------------------------- +# State extractors — one per supported state message type. Each takes a +# duck-typed message object and returns a flat list[float] state vector. +# +# These are deliberately decoupled from any rclpy/ROS2 imports so they can +# be unit-tested with simple SimpleNamespace mocks on machines without ROS2. +# --------------------------------------------------------------------------- + + +def _extract_joint_state(msg: Any) -> list[float]: + """sensor_msgs/JointState → joint positions (arm convention).""" + return [float(x) for x in msg.position] + + +def _extract_imu(msg: Any) -> list[float]: + """sensor_msgs/Imu → quaternion orientation [x, y, z, w] (4 DOF). + + Partial state — no position, no velocity. Use `odom` instead when you + have access to /mavros/local_position/odom for full drone state. + """ + o = msg.orientation + return [float(o.x), float(o.y), float(o.z), float(o.w)] + + +def _extract_odom(msg: Any) -> list[float]: + """nav_msgs/Odometry → pos (3) + quat orient (4) + linear vel (3) = 10 DOF. + + Matches the state shape declared in the shipped quadcopter preset. + The canonical drone topic is /mavros/local_position/odom. + """ + pose = msg.pose.pose + twist = msg.twist.twist + p = pose.position + o = pose.orientation + v = twist.linear + return [ + float(p.x), float(p.y), float(p.z), + float(o.x), float(o.y), float(o.z), float(o.w), + float(v.x), float(v.y), float(v.z), + ] + + +# State-msg-type -> (lazy msg class loader, extractor). The class loader is +# called only when the bridge is actually set up — keeps the optional ROS2 +# package deps out of the import path for unrelated codepaths. +_STATE_EXTRACTORS: dict[str, tuple[str, Any]] = { + "joint_state": ("joint_state", _extract_joint_state), + "imu": ("imu", _extract_imu), + "odom": ("odom", _extract_odom), +} + + +def _resolve_state_msg_class(state_msg_type: str) -> Any: + """Lazy-import the ROS2 message class for the requested state type. + + Raises ValueError on unknown type, ImportError if the relevant ROS2 + package isn't installed (with a hint about which apt/robostack package + provides it). + """ + if state_msg_type == "joint_state": + from sensor_msgs.msg import JointState + return JointState + if state_msg_type == "imu": + from sensor_msgs.msg import Imu + return Imu + if state_msg_type == "odom": + try: + from nav_msgs.msg import Odometry + except ImportError as exc: + raise ImportError( + "nav_msgs not installed. Required for --state-msg-type=odom. " + "On a ROS2 install, this ships with ros--nav-msgs " + "(apt) or the robostack ros--nav-msgs conda package." + ) from exc + return Odometry + raise ValueError( + f"unknown state_msg_type {state_msg_type!r}; expected one of: " + f"{', '.join(sorted(_STATE_EXTRACTORS))}" + ) + + def create_ros2_bridge_node( server: Any, *, @@ -57,6 +155,7 @@ def create_ros2_bridge_node( action_topic: str = "/reflex/actions", rate_hz: float = 20.0, node_name: str = "reflex_vla", + state_msg_type: str = "joint_state", ) -> Any: """Build a ROS2 node that wraps ``server.predict()`` as pub/sub. @@ -64,8 +163,27 @@ def create_ros2_bridge_node( latest message from each, and at ``rate_hz`` Hz invokes ``server.predict(image, instruction, state)`` and publishes the action chunk to ``action_topic`` as a flat Float32MultiArray. + + state_msg_type selects how the state vector is extracted: + - "joint_state": sensor_msgs/JointState.position (default; arms) + - "imu": sensor_msgs/Imu.orientation (drone; 4-DOF partial) + - "odom": nav_msgs/Odometry pose + twist (drone; 10-DOF full) """ - rclpy, Node, Image, JointState, String, Float32MultiArray = _require_rclpy() + rclpy, Node, Image, String, Float32MultiArray = _require_rclpy() + if state_msg_type not in _STATE_EXTRACTORS: + raise ValueError( + f"unknown state_msg_type {state_msg_type!r}; expected one of: " + f"{', '.join(sorted(_STATE_EXTRACTORS))}" + ) + StateMsgClass = _resolve_state_msg_class(state_msg_type) + _, state_extractor = _STATE_EXTRACTORS[state_msg_type] + + # If the loaded server has an embodiment config, surface a one-time + # warning when the extracted state vector length doesn't match what the + # embodiment expects. Silent mismatches caused real drone bugs (#121). + expected_state_dim = getattr( + getattr(server, "embodiment_config", None), "state_dim", None + ) class ReflexROS2Node(Node): """Bridge node. Also implements the `mcp.ros2_tools.ROS2Context` protocol @@ -78,9 +196,10 @@ def __init__(self) -> None: self._last_state: list[float] | None = None self._last_task: str = "" self._inference_count = 0 + self._state_dim_mismatch_warned = False self.create_subscription(Image, image_topic, self._image_cb, 10) - self.create_subscription(JointState, state_topic, self._state_cb, 10) + self.create_subscription(StateMsgClass, state_topic, self._state_cb, 10) self.create_subscription(String, task_topic, self._task_cb, 10) self._action_pub = self.create_publisher(Float32MultiArray, action_topic, 10) self._estop_pub = self.create_publisher(String, "/reflex/e_stop", 10) @@ -88,8 +207,8 @@ def __init__(self) -> None: self.get_logger().info( f"reflex ros2 node '{node_name}' up: subs={image_topic} + " - f"{state_topic} + {task_topic}, pub={action_topic} at " - f"{rate_hz:.1f} Hz" + f"{state_topic} ({state_msg_type}) + {task_topic}, " + f"pub={action_topic} at {rate_hz:.1f} Hz" ) def _image_cb(self, msg: Any) -> None: @@ -119,7 +238,33 @@ def _image_cb(self, msg: Any) -> None: ) def _state_cb(self, msg: Any) -> None: - self._last_state = [float(x) for x in msg.position] + try: + state = state_extractor(msg) + except (AttributeError, TypeError) as exc: + self.get_logger().error( + f"failed to extract state from {state_msg_type!r} message " + f"on {state_topic}: {exc}. Check --state-msg-type matches " + f"the actual message type being published." + ) + return + self._last_state = state + # One-time warning on state-dim mismatch — silent shape mismatches + # are the most common drone deployment failure mode (silently + # garbage actions instead of a loud error). Fire once per node. + if ( + expected_state_dim is not None + and not self._state_dim_mismatch_warned + and len(state) != expected_state_dim + ): + self.get_logger().warning( + f"state vector length {len(state)} from " + f"{state_msg_type!r} extractor does NOT match embodiment " + f"state_dim {expected_state_dim}. The policy will receive " + f"the wrong shape and likely produce nonsense actions. " + f"Either change --state-msg-type or load an embodiment " + f"config whose state_dim matches your robot's state." + ) + self._state_dim_mismatch_warned = True def _task_cb(self, msg: Any) -> None: self._last_task = str(msg.data) @@ -218,6 +363,7 @@ def run_ros2_bridge( action_topic: str = "/reflex/actions", rate_hz: float = 20.0, node_name: str = "reflex_vla", + state_msg_type: str = "joint_state", mcp: bool = False, mcp_transport: str = "stdio", mcp_port: int = 8001, @@ -237,7 +383,7 @@ def run_ros2_bridge( - mcp=True, transport="http": MCP runs in background thread on ``mcp_port``; ``rclpy.spin(node)`` blocks on main thread. """ - rclpy, _, _, _, _, _ = _require_rclpy() + rclpy, _, _, _, _ = _require_rclpy() from reflex.runtime.server import ReflexServer if mcp and mcp_transport not in ("stdio", "http"): @@ -266,6 +412,7 @@ def run_ros2_bridge( action_topic=action_topic, rate_hz=rate_hz, node_name=node_name, + state_msg_type=state_msg_type, ) if not mcp: diff --git a/tests/test_ros2_bridge.py b/tests/test_ros2_bridge.py index b4ab5e1..41ba2c3 100644 --- a/tests/test_ros2_bridge.py +++ b/tests/test_ros2_bridge.py @@ -100,6 +100,11 @@ def get_logger(self): sensor_msgs_msg = types.ModuleType("sensor_msgs.msg") sensor_msgs_msg.Image = type("Image", (), {}) sensor_msgs_msg.JointState = type("JointState", (), {}) + sensor_msgs_msg.Imu = type("Imu", (), {}) + + nav_msgs = types.ModuleType("nav_msgs") + nav_msgs_msg = types.ModuleType("nav_msgs.msg") + nav_msgs_msg.Odometry = type("Odometry", (), {}) std_msgs = types.ModuleType("std_msgs") std_msgs_msg = types.ModuleType("std_msgs.msg") @@ -114,6 +119,8 @@ def __init__(self): monkeypatch.setitem(sys.modules, "rclpy.node", rclpy_node) monkeypatch.setitem(sys.modules, "sensor_msgs", sensor_msgs) monkeypatch.setitem(sys.modules, "sensor_msgs.msg", sensor_msgs_msg) + monkeypatch.setitem(sys.modules, "nav_msgs", nav_msgs) + monkeypatch.setitem(sys.modules, "nav_msgs.msg", nav_msgs_msg) monkeypatch.setitem(sys.modules, "std_msgs", std_msgs) monkeypatch.setitem(sys.modules, "std_msgs.msg", std_msgs_msg) @@ -251,3 +258,215 @@ def test_tick_handles_server_error_gracefully(monkeypatch): # predict called but no publish happened server.predict.assert_called_once() node._action_pub.publish.assert_not_called() + + +# --------------------------------------------------------------------------- +# State extractor unit tests — exercise the decoupled extractor functions +# directly with SimpleNamespace mocks. No rclpy install needed. +# --------------------------------------------------------------------------- + +from types import SimpleNamespace # noqa: E402 + + +class TestJointStateExtractor: + def test_extracts_position_field(self): + from reflex.runtime.ros2_bridge import _extract_joint_state + msg = SimpleNamespace(position=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]) + state = _extract_joint_state(msg) + assert state == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + + def test_normalises_to_floats(self): + from reflex.runtime.ros2_bridge import _extract_joint_state + msg = SimpleNamespace(position=(1, 2, 3)) + state = _extract_joint_state(msg) + assert state == [1.0, 2.0, 3.0] + assert all(isinstance(v, float) for v in state) + + +class TestImuExtractor: + def test_extracts_quaternion_xyzw(self): + """ROS REP-103 convention — quaternion stored as (x, y, z, w).""" + from reflex.runtime.ros2_bridge import _extract_imu + msg = SimpleNamespace( + orientation=SimpleNamespace(x=0.1, y=0.2, z=0.3, w=0.9), + ) + assert _extract_imu(msg) == [0.1, 0.2, 0.3, 0.9] + + def test_output_is_four_dof_partial_state(self): + from reflex.runtime.ros2_bridge import _extract_imu + msg = SimpleNamespace( + orientation=SimpleNamespace(x=0.0, y=0.0, z=0.0, w=1.0), + ) + assert len(_extract_imu(msg)) == 4 + + +class TestOdometryExtractor: + @staticmethod + def _mock_odom(*, p, o, v): + return SimpleNamespace( + pose=SimpleNamespace(pose=SimpleNamespace( + position=SimpleNamespace(x=p[0], y=p[1], z=p[2]), + orientation=SimpleNamespace(x=o[0], y=o[1], z=o[2], w=o[3]), + )), + twist=SimpleNamespace(twist=SimpleNamespace( + linear=SimpleNamespace(x=v[0], y=v[1], z=v[2]), + )), + ) + + def test_extracts_pos_orient_vel_in_order(self): + from reflex.runtime.ros2_bridge import _extract_odom + msg = self._mock_odom( + p=(1.0, 2.0, 3.0), + o=(0.0, 0.0, 0.0, 1.0), + v=(0.1, 0.2, 0.3), + ) + assert _extract_odom(msg) == [1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0, 0.1, 0.2, 0.3] + + def test_output_is_ten_dof(self): + from reflex.runtime.ros2_bridge import _extract_odom + msg = self._mock_odom(p=(0, 0, 0), o=(0, 0, 0, 1), v=(0, 0, 0)) + assert len(_extract_odom(msg)) == 10 + + def test_matches_quadcopter_preset_state_dim(self): + """Cross-layer pin: if the quadcopter preset state_dim ever drifts + from the odom extractor length, this breaks at CI time at the right + layer to catch which side moved.""" + from reflex.embodiments import EmbodimentConfig + from reflex.runtime.ros2_bridge import _extract_odom + cfg = EmbodimentConfig.load_preset("quadcopter") + msg = self._mock_odom(p=(0, 0, 0), o=(0, 0, 0, 1), v=(0, 0, 0)) + assert len(_extract_odom(msg)) == cfg.state_dim + + +class TestRegistryAndResolution: + def test_all_documented_types_registered(self): + from reflex.runtime.ros2_bridge import _STATE_EXTRACTORS + assert set(_STATE_EXTRACTORS) == {"joint_state", "imu", "odom"} + + def test_resolve_unknown_type_raises(self): + from reflex.runtime.ros2_bridge import _resolve_state_msg_class + with pytest.raises(ValueError, match="unknown state_msg_type"): + _resolve_state_msg_class("not_a_real_msg_type") + + +# --------------------------------------------------------------------------- +# Integration tests — verify create_ros2_bridge_node dispatches the right +# extractor for each state_msg_type and surfaces mismatch warnings. +# --------------------------------------------------------------------------- + + +class TestStateMsgTypeDispatch: + def test_default_is_joint_state(self, monkeypatch): + _install_fake_rclpy(monkeypatch) + from reflex.runtime.ros2_bridge import create_ros2_bridge_node + node = create_ros2_bridge_node(MagicMock()) + msg = SimpleNamespace(position=[0.1, 0.2, 0.3]) + node._state_cb(msg) + assert node._last_state == [0.1, 0.2, 0.3] + + def test_imu_dispatch(self, monkeypatch): + _install_fake_rclpy(monkeypatch) + from reflex.runtime.ros2_bridge import create_ros2_bridge_node + node = create_ros2_bridge_node( + MagicMock(), state_msg_type="imu", state_topic="/mavros/imu/data", + ) + msg = SimpleNamespace( + orientation=SimpleNamespace(x=0.0, y=0.0, z=0.0, w=1.0), + ) + node._state_cb(msg) + assert node._last_state == [0.0, 0.0, 0.0, 1.0] + + def test_odom_dispatch(self, monkeypatch): + _install_fake_rclpy(monkeypatch) + from reflex.runtime.ros2_bridge import create_ros2_bridge_node + node = create_ros2_bridge_node( + MagicMock(), + state_msg_type="odom", + state_topic="/mavros/local_position/odom", + ) + msg = TestOdometryExtractor._mock_odom( + p=(1.0, 2.0, 3.0), + o=(0.0, 0.0, 0.0, 1.0), + v=(0.1, 0.2, 0.3), + ) + node._state_cb(msg) + assert node._last_state == [ + 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0, 0.1, 0.2, 0.3, + ] + + def test_unknown_msg_type_raises(self, monkeypatch): + _install_fake_rclpy(monkeypatch) + from reflex.runtime.ros2_bridge import create_ros2_bridge_node + with pytest.raises(ValueError, match="unknown state_msg_type"): + create_ros2_bridge_node(MagicMock(), state_msg_type="lidar") + + def test_extractor_failure_logged_not_raised(self, monkeypatch): + """A malformed message must not crash the node — log and skip.""" + _install_fake_rclpy(monkeypatch) + from reflex.runtime.ros2_bridge import create_ros2_bridge_node + node = create_ros2_bridge_node(MagicMock(), state_msg_type="imu") + # IMU extractor expects msg.orientation — feed a JointState-shaped msg + bad_msg = SimpleNamespace(position=[0.1, 0.2]) + node._state_cb(bad_msg) + # _last_state stays None, no exception + assert node._last_state is None + + def test_state_dim_mismatch_warns_once(self, monkeypatch): + """When a loaded embodiment expects state_dim=10 (quadcopter) but + the extractor returns 4 (imu), surface a warning. Fires once per + node — the same mismatch every tick would spam the log.""" + _install_fake_rclpy(monkeypatch) + from reflex.embodiments import EmbodimentConfig + from reflex.runtime.ros2_bridge import create_ros2_bridge_node + + server = MagicMock() + server.embodiment_config = EmbodimentConfig.load_preset("quadcopter") + + warnings: list[str] = [] + original_create = create_ros2_bridge_node + + node = original_create(server, state_msg_type="imu") + # Capture warnings from the node's logger (FakeNode.get_logger returns + # a fresh MagicMock each call, so monkey-patch at the node level). + node.get_logger = lambda warnings_ref=warnings: SimpleNamespace( + info=lambda *a, **k: None, + warning=lambda m, *a, **k: warnings_ref.append(str(m)), + error=lambda *a, **k: None, + ) + + msg = SimpleNamespace( + orientation=SimpleNamespace(x=0.0, y=0.0, z=0.0, w=1.0), + ) + node._state_cb(msg) + node._state_cb(msg) # second tick — should NOT re-warn + node._state_cb(msg) + mismatch_warnings = [ + w for w in warnings if "does NOT match embodiment state_dim" in w + ] + assert len(mismatch_warnings) == 1, mismatch_warnings + + def test_state_dim_match_no_warning(self, monkeypatch): + """When extractor output length matches embodiment state_dim, no + warning fires.""" + _install_fake_rclpy(monkeypatch) + from reflex.embodiments import EmbodimentConfig + from reflex.runtime.ros2_bridge import create_ros2_bridge_node + + server = MagicMock() + server.embodiment_config = EmbodimentConfig.load_preset("quadcopter") + + warnings: list[str] = [] + node = create_ros2_bridge_node(server, state_msg_type="odom") + node.get_logger = lambda warnings_ref=warnings: SimpleNamespace( + info=lambda *a, **k: None, + warning=lambda m, *a, **k: warnings_ref.append(str(m)), + error=lambda *a, **k: None, + ) + msg = TestOdometryExtractor._mock_odom( + p=(0, 0, 0), o=(0, 0, 0, 1), v=(0, 0, 0), + ) + node._state_cb(msg) + mismatch_warnings = [ + w for w in warnings if "does NOT match embodiment state_dim" in w + ] + assert mismatch_warnings == []