Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/reflex/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2142,7 +2142,19 @@ def ros2_serve(
),
state_topic: str = typer.Option(
"/joint_states",
help="sensor_msgs/JointState topic — .position field becomes the state vector",
help="State topic. Default reads .position from a JointState (arm "
"convention). For drones, pair --state-msg-type=odom with "
"/mavros/local_position/odom for full 10-DOF state, or --state-"
"msg-type=imu with /mavros/imu/data for 4-DOF orientation-only.",
),
state_msg_type: str = typer.Option(
"joint_state",
"--state-msg-type",
help="How to interpret messages on --state-topic. One of: "
"'joint_state' (sensor_msgs/JointState .position, arms; default), "
"'imu' (sensor_msgs/Imu .orientation quaternion, drone partial "
"state — 4 DOF), 'odom' (nav_msgs/Odometry pose + linear twist, "
"drone full state — 10 DOF, matches the quadcopter preset).",
),
task_topic: str = typer.Option(
"/reflex/task",
Expand Down Expand Up @@ -2208,7 +2220,7 @@ def ros2_serve(
console.print(f" export_dir: {export_dir}")
console.print(f" node_name: {node_name}")
console.print(f" rate_hz: {rate_hz}")
console.print(f" subs: {image_topic}, {state_topic}, {task_topic}")
console.print(f" subs: {image_topic}, {state_topic} ({state_msg_type}), {task_topic}")
console.print(f" pub: {action_topic}")
if mcp:
console.print(f" mcp: {mcp_transport}" + (f" (port {mcp_port})" if mcp_transport == "http" else ""))
Expand All @@ -2223,10 +2235,16 @@ def ros2_serve(
action_topic=action_topic,
rate_hz=rate_hz,
node_name=node_name,
state_msg_type=state_msg_type,
mcp=mcp,
mcp_transport=mcp_transport,
mcp_port=mcp_port,
)
except ValueError as exc:
# _resolve_state_msg_class / create_ros2_bridge_node raise ValueError
# for unknown --state-msg-type values.
console.print(f"[red]{exc}[/red]")
raise typer.Exit(1)
except ImportError as exc:
console.print(f"[red]{exc}[/red]")
raise typer.Exit(2)
Expand Down
167 changes: 157 additions & 10 deletions src/reflex/runtime/ros2_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,21 @@
Default topic layout (override via CLI flags):
subs:
/camera/image_raw sensor_msgs/msg/Image (rgb8)
/joint_states sensor_msgs/msg/JointState (positions → state vector)
/joint_states state topic (default JointState — see below)
/reflex/task std_msgs/msg/String (text instruction)
pub:
/reflex/actions std_msgs/msg/Float32MultiArray (flat chunk × action_dim)

State extractors (selected via --state-msg-type):
joint_state sensor_msgs/msg/JointState .position (arms; default)
imu sensor_msgs/msg/Imu .orientation (quat) (drone, partial)
odom nav_msgs/msg/Odometry pose + linear twist (drone, full state)

Recommended drone topic: /mavros/local_position/odom (nav_msgs/Odometry) —
gives the policy position + orientation + linear velocity in one message,
matching the 10-DOF state shape used by the shipped quadcopter preset.
The IMU path only yields 4 DOF (orientation only) — useful as a fallback
when full odometry isn't available, but expect reduced control quality.
"""
from __future__ import annotations

Expand All @@ -30,13 +41,19 @@


def _require_rclpy():
"""Import rclpy + message modules or raise a helpful ImportError."""
"""Import rclpy + always-needed message modules or raise a helpful ImportError.

Returns the core types every bridge configuration needs. State-specific
message classes (Imu, Odometry) are resolved lazily via
`_resolve_state_msg` so a missing optional package doesn't break arm
deployments that only need JointState.
"""
try:
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image, JointState
from sensor_msgs.msg import Image
from std_msgs.msg import Float32MultiArray, String
return rclpy, Node, Image, JointState, String, Float32MultiArray
return rclpy, Node, Image, String, Float32MultiArray
except ImportError as exc:
raise ImportError(
"rclpy not available. The reflex ROS2 bridge requires a ROS2 install "
Expand All @@ -48,6 +65,87 @@ def _require_rclpy():
) from exc


# ---------------------------------------------------------------------------
# State extractors — one per supported state message type. Each takes a
# duck-typed message object and returns a flat list[float] state vector.
#
# These are deliberately decoupled from any rclpy/ROS2 imports so they can
# be unit-tested with simple SimpleNamespace mocks on machines without ROS2.
# ---------------------------------------------------------------------------


def _extract_joint_state(msg: Any) -> list[float]:
"""sensor_msgs/JointState → joint positions (arm convention)."""
return [float(x) for x in msg.position]


def _extract_imu(msg: Any) -> list[float]:
"""sensor_msgs/Imu → quaternion orientation [x, y, z, w] (4 DOF).

Partial state — no position, no velocity. Use `odom` instead when you
have access to /mavros/local_position/odom for full drone state.
"""
o = msg.orientation
return [float(o.x), float(o.y), float(o.z), float(o.w)]


def _extract_odom(msg: Any) -> list[float]:
"""nav_msgs/Odometry → pos (3) + quat orient (4) + linear vel (3) = 10 DOF.

Matches the state shape declared in the shipped quadcopter preset.
The canonical drone topic is /mavros/local_position/odom.
"""
pose = msg.pose.pose
twist = msg.twist.twist
p = pose.position
o = pose.orientation
v = twist.linear
return [
float(p.x), float(p.y), float(p.z),
float(o.x), float(o.y), float(o.z), float(o.w),
float(v.x), float(v.y), float(v.z),
]


# State-msg-type -> (lazy msg class loader, extractor). The class loader is
# called only when the bridge is actually set up — keeps the optional ROS2
# package deps out of the import path for unrelated codepaths.
_STATE_EXTRACTORS: dict[str, tuple[str, Any]] = {
"joint_state": ("joint_state", _extract_joint_state),
"imu": ("imu", _extract_imu),
"odom": ("odom", _extract_odom),
}


def _resolve_state_msg_class(state_msg_type: str) -> Any:
"""Lazy-import the ROS2 message class for the requested state type.

Raises ValueError on unknown type, ImportError if the relevant ROS2
package isn't installed (with a hint about which apt/robostack package
provides it).
"""
if state_msg_type == "joint_state":
from sensor_msgs.msg import JointState
return JointState
if state_msg_type == "imu":
from sensor_msgs.msg import Imu
return Imu
if state_msg_type == "odom":
try:
from nav_msgs.msg import Odometry
except ImportError as exc:
raise ImportError(
"nav_msgs not installed. Required for --state-msg-type=odom. "
"On a ROS2 install, this ships with ros-<distro>-nav-msgs "
"(apt) or the robostack ros-<distro>-nav-msgs conda package."
) from exc
return Odometry
raise ValueError(
f"unknown state_msg_type {state_msg_type!r}; expected one of: "
f"{', '.join(sorted(_STATE_EXTRACTORS))}"
)


def create_ros2_bridge_node(
server: Any,
*,
Expand All @@ -57,15 +155,35 @@ def create_ros2_bridge_node(
action_topic: str = "/reflex/actions",
rate_hz: float = 20.0,
node_name: str = "reflex_vla",
state_msg_type: str = "joint_state",
) -> Any:
"""Build a ROS2 node that wraps ``server.predict()`` as pub/sub.

The returned node subscribes to image + state + task topics, caches the
latest message from each, and at ``rate_hz`` Hz invokes
``server.predict(image, instruction, state)`` and publishes the action
chunk to ``action_topic`` as a flat Float32MultiArray.

state_msg_type selects how the state vector is extracted:
- "joint_state": sensor_msgs/JointState.position (default; arms)
- "imu": sensor_msgs/Imu.orientation (drone; 4-DOF partial)
- "odom": nav_msgs/Odometry pose + twist (drone; 10-DOF full)
"""
rclpy, Node, Image, JointState, String, Float32MultiArray = _require_rclpy()
rclpy, Node, Image, String, Float32MultiArray = _require_rclpy()
if state_msg_type not in _STATE_EXTRACTORS:
raise ValueError(
f"unknown state_msg_type {state_msg_type!r}; expected one of: "
f"{', '.join(sorted(_STATE_EXTRACTORS))}"
)
StateMsgClass = _resolve_state_msg_class(state_msg_type)
_, state_extractor = _STATE_EXTRACTORS[state_msg_type]

# If the loaded server has an embodiment config, surface a one-time
# warning when the extracted state vector length doesn't match what the
# embodiment expects. Silent mismatches caused real drone bugs (#121).
expected_state_dim = getattr(
getattr(server, "embodiment_config", None), "state_dim", None
)

class ReflexROS2Node(Node):
"""Bridge node. Also implements the `mcp.ros2_tools.ROS2Context` protocol
Expand All @@ -78,18 +196,19 @@ def __init__(self) -> None:
self._last_state: list[float] | None = None
self._last_task: str = ""
self._inference_count = 0
self._state_dim_mismatch_warned = False

self.create_subscription(Image, image_topic, self._image_cb, 10)
self.create_subscription(JointState, state_topic, self._state_cb, 10)
self.create_subscription(StateMsgClass, state_topic, self._state_cb, 10)
self.create_subscription(String, task_topic, self._task_cb, 10)
self._action_pub = self.create_publisher(Float32MultiArray, action_topic, 10)
self._estop_pub = self.create_publisher(String, "/reflex/e_stop", 10)
self._timer = self.create_timer(1.0 / max(0.1, rate_hz), self._tick)

self.get_logger().info(
f"reflex ros2 node '{node_name}' up: subs={image_topic} + "
f"{state_topic} + {task_topic}, pub={action_topic} at "
f"{rate_hz:.1f} Hz"
f"{state_topic} ({state_msg_type}) + {task_topic}, "
f"pub={action_topic} at {rate_hz:.1f} Hz"
)

def _image_cb(self, msg: Any) -> None:
Expand Down Expand Up @@ -119,7 +238,33 @@ def _image_cb(self, msg: Any) -> None:
)

def _state_cb(self, msg: Any) -> None:
self._last_state = [float(x) for x in msg.position]
try:
state = state_extractor(msg)
except (AttributeError, TypeError) as exc:
self.get_logger().error(
f"failed to extract state from {state_msg_type!r} message "
f"on {state_topic}: {exc}. Check --state-msg-type matches "
f"the actual message type being published."
)
return
self._last_state = state
# One-time warning on state-dim mismatch — silent shape mismatches
# are the most common drone deployment failure mode (silently
# garbage actions instead of a loud error). Fire once per node.
if (
expected_state_dim is not None
and not self._state_dim_mismatch_warned
and len(state) != expected_state_dim
):
self.get_logger().warning(
f"state vector length {len(state)} from "
f"{state_msg_type!r} extractor does NOT match embodiment "
f"state_dim {expected_state_dim}. The policy will receive "
f"the wrong shape and likely produce nonsense actions. "
f"Either change --state-msg-type or load an embodiment "
f"config whose state_dim matches your robot's state."
)
self._state_dim_mismatch_warned = True

def _task_cb(self, msg: Any) -> None:
self._last_task = str(msg.data)
Expand Down Expand Up @@ -218,6 +363,7 @@ def run_ros2_bridge(
action_topic: str = "/reflex/actions",
rate_hz: float = 20.0,
node_name: str = "reflex_vla",
state_msg_type: str = "joint_state",
mcp: bool = False,
mcp_transport: str = "stdio",
mcp_port: int = 8001,
Expand All @@ -237,7 +383,7 @@ def run_ros2_bridge(
- mcp=True, transport="http": MCP runs in background thread on
``mcp_port``; ``rclpy.spin(node)`` blocks on main thread.
"""
rclpy, _, _, _, _, _ = _require_rclpy()
rclpy, _, _, _, _ = _require_rclpy()
from reflex.runtime.server import ReflexServer

if mcp and mcp_transport not in ("stdio", "http"):
Expand Down Expand Up @@ -266,6 +412,7 @@ def run_ros2_bridge(
action_topic=action_topic,
rate_hz=rate_hz,
node_name=node_name,
state_msg_type=state_msg_type,
)

if not mcp:
Expand Down
Loading
Loading