diff --git a/Cargo.toml b/Cargo.toml index aaa64d3e..61c8981f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "crates/hiroz-schema", "crates/hiroz-msgs", "crates/hiroz-tests", + "crates/hiroz-tf", "crates/hiroz-console", "crates/hiroz-bridge", "crates/hiroz/examples/protobuf_demo", diff --git a/crates/hiroz-codegen/assets/jazzy/dependencies.json b/crates/hiroz-codegen/assets/jazzy/dependencies.json index 4790cb8e..884acbf4 100644 --- a/crates/hiroz-codegen/assets/jazzy/dependencies.json +++ b/crates/hiroz-codegen/assets/jazzy/dependencies.json @@ -28,6 +28,11 @@ "std_msgs" ] }, + "tf2_msgs": { + "dependencies": [ + "geometry_msgs" + ] + }, "sensor_msgs": { "dependencies": [ "builtin_interfaces", diff --git a/crates/hiroz-codegen/assets/jazzy/tf2_msgs/msg/TFMessage.msg b/crates/hiroz-codegen/assets/jazzy/tf2_msgs/msg/TFMessage.msg new file mode 100644 index 00000000..fda1e4d0 --- /dev/null +++ b/crates/hiroz-codegen/assets/jazzy/tf2_msgs/msg/TFMessage.msg @@ -0,0 +1 @@ +geometry_msgs/TransformStamped[] transforms diff --git a/crates/hiroz-codegen/src/bin/export_json.rs b/crates/hiroz-codegen/src/bin/export_json.rs index 159fadd0..7cecf252 100644 --- a/crates/hiroz-codegen/src/bin/export_json.rs +++ b/crates/hiroz-codegen/src/bin/export_json.rs @@ -116,6 +116,7 @@ fn main() -> Result<()> { external_crate: None, local_packages: std::collections::HashSet::new(), json_out: Some(json_path.clone()), + protobuf_excluded_packages: std::collections::HashSet::new(), }; let generator = hiroz_codegen::MessageGenerator::new(config); diff --git a/crates/hiroz-codegen/src/lib.rs b/crates/hiroz-codegen/src/lib.rs index a7e49dc5..cbfd7677 100644 --- a/crates/hiroz-codegen/src/lib.rs +++ b/crates/hiroz-codegen/src/lib.rs @@ -44,6 +44,10 @@ pub struct GeneratorConfig { /// Output JSON definitions for external generators (Go, Python, etc.) pub json_out: Option, + + /// Packages to skip during protobuf generation (CDR generation is unaffected). + /// Useful when a package's message types are intentionally not exposed via protobuf. + pub protobuf_excluded_packages: std::collections::HashSet, } /// Message generator that orchestrates parsing, resolution, and code generation @@ -507,6 +511,18 @@ impl MessageGenerator { fn generate_protobuf_types(&self, messages: &[ResolvedMessage]) -> Result<()> { use crate::protobuf_generator::ProtobufMessageGenerator; + let filtered: Vec = messages + .iter() + .filter(|m| { + !self + .config + .protobuf_excluded_packages + .contains(&m.parsed.package) + }) + .cloned() + .collect(); + let messages = filtered.as_slice(); + let proto_dir = self.config.output_dir.join("proto"); let generator = ProtobufMessageGenerator::new(&proto_dir); @@ -704,6 +720,7 @@ pub fn generate_user_messages(output_dir: &Path, is_humble: bool) -> Result<()> external_crate: Some("hiroz_msgs".to_string()), local_packages, json_out: None, + protobuf_excluded_packages: std::collections::HashSet::new(), }; let generator = MessageGenerator::new(config); diff --git a/crates/hiroz-codegen/src/protobuf_generator.rs b/crates/hiroz-codegen/src/protobuf_generator.rs index 047ac265..35cb4732 100644 --- a/crates/hiroz-codegen/src/protobuf_generator.rs +++ b/crates/hiroz-codegen/src/protobuf_generator.rs @@ -322,15 +322,53 @@ impl ::hiroz::msg::ZMessage for {proto_type} {{ Ok(impls) } - /// Convert ROS message name to prost naming convention + /// Convert a ROS PascalCase message name to prost's UpperCamelCase convention. + /// + /// Prost (via heck) splits words at: + /// - lowercase → uppercase transitions ("fooBar" → ["foo","Bar"]) + /// - uppercase-run boundary before a lowercase letter + /// ("TFMessage" → ["TF","Message"], "ColorRGBA" → ["Color","RGBA"]) + /// + /// Each word is then title-cased (first char upper, rest lower). fn convert_to_prost_naming(&self, name: &str) -> String { - // Handle specific known cases where prost naming differs - match name { - "MultiDOFJointState" => "MultiDofJointState".to_string(), - "ColorRGBA" => "ColorRgba".to_string(), - "UUID" => "Uuid".to_string(), - // Add more mappings as needed - _ => name.to_string(), + let chars: Vec = name.chars().collect(); + let mut words: Vec = Vec::new(); + let mut current = String::new(); + + for (i, &c) in chars.iter().enumerate() { + let is_word_start = if c.is_uppercase() { + let prev = if i > 0 { Some(chars[i - 1]) } else { None }; + let next = chars.get(i + 1).copied(); + match (prev, next) { + (None, _) => false, // first char: never a boundary + (Some(p), _) if p.is_lowercase() => true, // lower→upper + (Some(p), _) if p.is_ascii_digit() => true, // digit→upper + (Some(p), Some(n)) if p.is_uppercase() && n.is_lowercase() => true, // run→lower + _ => false, + } + } else { + false + }; + + if is_word_start && !current.is_empty() { + words.push(current.clone()); + current = String::new(); + } + current.push(c); } + if !current.is_empty() { + words.push(current); + } + + words + .iter() + .map(|w| { + let mut cs = w.chars(); + match cs.next() { + None => String::new(), + Some(first) => first.to_uppercase().to_string() + &cs.as_str().to_lowercase(), + } + }) + .collect() } } diff --git a/crates/hiroz-msgs/Cargo.toml b/crates/hiroz-msgs/Cargo.toml index 2f81f1c6..ebe87cd0 100644 --- a/crates/hiroz-msgs/Cargo.toml +++ b/crates/hiroz-msgs/Cargo.toml @@ -51,6 +51,7 @@ std_msgs = [] geometry_msgs = ["std_msgs"] sensor_msgs = ["std_msgs", "geometry_msgs"] nav_msgs = ["std_msgs", "geometry_msgs"] +tf2_msgs = ["geometry_msgs"] example_interfaces = [] action_tutorials_interfaces = [] test_msgs = [] @@ -63,6 +64,7 @@ all_msgs = [ "geometry_msgs", "sensor_msgs", "nav_msgs", + "tf2_msgs", "example_interfaces", "action_tutorials_interfaces", "rcl_interfaces", diff --git a/crates/hiroz-msgs/build.rs b/crates/hiroz-msgs/build.rs index ba7c3203..b5bd94e7 100644 --- a/crates/hiroz-msgs/build.rs +++ b/crates/hiroz-msgs/build.rs @@ -14,6 +14,7 @@ fn main() -> Result<()> { println!("cargo::rustc-check-cfg=cfg(has_example_interfaces)"); println!("cargo::rustc-check-cfg=cfg(has_test_msgs)"); println!("cargo::rustc-check-cfg=cfg(has_rcl_interfaces)"); + println!("cargo::rustc-check-cfg=cfg(has_tf2_msgs)"); // Detect ROS version and emit cfg let is_humble = detect_ros_version(); @@ -38,6 +39,7 @@ fn main() -> Result<()> { external_crate: None, // All packages are local in hiroz-msgs local_packages: std::collections::HashSet::new(), // All packages are local json_out: None, // Not needed for Rust codegen + protobuf_excluded_packages: std::collections::HashSet::new(), }; let generator = hiroz_codegen::MessageGenerator::new(config); @@ -255,6 +257,10 @@ fn get_all_packages(is_humble: bool) -> Vec<&'static str> { names.push("nav_msgs"); } + if env::var("CARGO_FEATURE_TF2_MSGS").is_ok() { + names.push("tf2_msgs"); + } + if env::var("CARGO_FEATURE_EXAMPLE_INTERFACES").is_ok() { names.push("example_interfaces"); } diff --git a/crates/hiroz-tests/Cargo.toml b/crates/hiroz-tests/Cargo.toml index 1b90e8e8..efd67b8c 100644 --- a/crates/hiroz-tests/Cargo.toml +++ b/crates/hiroz-tests/Cargo.toml @@ -11,6 +11,7 @@ publish = false [dependencies] hiroz = { path = "../hiroz", default-features = false, features = ["protobuf"] } hiroz-msgs = { path = "../hiroz-msgs", default-features = false, optional = true } +hiroz-tf = { path = "../hiroz-tf", default-features = false, optional = true } hiroz-cdr = { path = "../hiroz-cdr" } hiroz-schema = { path = "../hiroz-schema" } protobuf_demo = { path = "../hiroz/examples/protobuf_demo" } @@ -58,8 +59,16 @@ humble-jazzy-bridge-tests = [ # This enables testing hiroz with DDS-based ROS 2 nodes via zenoh-bridge-ros2dds ros2dds-interop = ["ros-msgs", "hiroz/ros2dds", "hiroz/rmw-zenoh"] -# ROS 2 distro compatibility - propagate to hiroz and hiroz-msgs -humble = ["hiroz/humble", "hiroz-msgs/humble"] -jazzy = ["hiroz/jazzy", "hiroz-msgs/jazzy"] -rolling = ["hiroz/rolling", "hiroz-msgs/rolling"] -kilted = ["hiroz/kilted", "hiroz-msgs/kilted"] +# TF integration tests +tf-tests = [ + "dep:hiroz-tf", + "dep:hiroz-msgs", + "hiroz-msgs/tf2_msgs", + "hiroz-msgs/geometry_msgs", +] + +# ROS 2 distro compatibility - propagate to hiroz, hiroz-msgs, and hiroz-tf +humble = ["hiroz/humble", "hiroz-msgs/humble", "hiroz-tf?/humble"] +jazzy = ["hiroz/jazzy", "hiroz-msgs/jazzy", "hiroz-tf?/jazzy"] +rolling = ["hiroz/rolling", "hiroz-msgs/rolling", "hiroz-tf?/rolling"] +kilted = ["hiroz/kilted", "hiroz-msgs/kilted", "hiroz-tf?/kilted"] diff --git a/crates/hiroz-tests/tests/tf_integration.rs b/crates/hiroz-tests/tests/tf_integration.rs new file mode 100644 index 00000000..873e70f7 --- /dev/null +++ b/crates/hiroz-tests/tests/tf_integration.rs @@ -0,0 +1,337 @@ +//! Integration tests for hiroz-tf Buffer. +//! +//! Tests require a Zenoh router (provided by TestRouter) and compile only +//! when the `tf-tests` feature is enabled. + +#![cfg(feature = "tf-tests")] + +mod common; + +use std::time::Duration; + +use common::{TestRouter, create_hiroz_context_with_endpoint}; +use hiroz::Builder; +use hiroz::qos::{QosDurability, QosHistory, QosProfile, QosReliability}; +use hiroz::time::ZTime; +use hiroz_msgs::builtin_interfaces::Time; +use hiroz_msgs::geometry_msgs::{Quaternion, Transform, TransformStamped, Vector3}; +use hiroz_msgs::std_msgs::Header; +use hiroz_msgs::tf2_msgs::TFMessage; +use hiroz_tf::{Buffer, StaticTransformBroadcaster, TransformBroadcaster, WaitError}; + +fn make_tf(parent: &str, child: &str, sec: i32, x: f64) -> TransformStamped { + TransformStamped { + header: Header { + frame_id: parent.to_string(), + stamp: Time { sec, nanosec: 0 }, + }, + child_frame_id: child.to_string(), + transform: Transform { + translation: Vector3 { x, y: 0.0, z: 0.0 }, + rotation: Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }, + }, + } +} + +/// Publish `/tf` and verify the Buffer receives and exposes it. +#[tokio::test(flavor = "multi_thread")] +async fn tf_buffer_receives_dynamic_transform() { + let router = TestRouter::new(); + let ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let node = ctx.create_node("tf_test_node").build().unwrap(); + let buffer = Buffer::new(&node).unwrap(); + + // Publisher node on same router + let pub_ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let pub_node = pub_ctx.create_node("tf_publisher").build().unwrap(); + let tf_pub = pub_node + .create_pub::("/tf") + .with_qos(QosProfile { + reliability: QosReliability::Reliable, + durability: QosDurability::Volatile, + history: QosHistory::KeepLast(std::num::NonZeroUsize::new(100).unwrap()), + ..Default::default() + }) + .build() + .unwrap(); + + // Wait for subscription to be established + tf_pub + .wait_for_subscription(1, Duration::from_secs(5)) + .await; + + tf_pub + .async_publish(&TFMessage { + transforms: vec![make_tf("map", "odom", 10, 3.0)], + }) + .await + .unwrap(); + + // Give the callback time to fire + tokio::time::sleep(Duration::from_millis(200)).await; + + let tf = buffer + .lookup_transform("map", "odom", ZTime::zero()) + .unwrap(); + assert!( + (tf.transform.translation.x - 3.0).abs() < 1e-6, + "expected x=3.0, got {}", + tf.transform.translation.x + ); +} + +/// `/tf_static` with TransientLocal: new subscriber gets old static transforms. +#[tokio::test(flavor = "multi_thread")] +async fn tf_static_transient_local_replayed_on_connect() { + let router = TestRouter::new(); + + // Publish static transform BEFORE creating the Buffer subscriber + let pub_ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let pub_node = pub_ctx.create_node("tf_static_publisher").build().unwrap(); + let tf_static_pub = pub_node + .create_pub::("/tf_static") + .with_qos(QosProfile { + reliability: QosReliability::Reliable, + durability: QosDurability::TransientLocal, + history: QosHistory::KeepLast(std::num::NonZeroUsize::new(100).unwrap()), + ..Default::default() + }) + .build() + .unwrap(); + + tf_static_pub + .async_publish(&TFMessage { + transforms: vec![make_tf("map", "sensor", 0, 0.5)], + }) + .await + .unwrap(); + + // Small delay so the publication is stored in the TransientLocal publisher + tokio::time::sleep(Duration::from_millis(200)).await; + + // NOW create the buffer — it subscribes AFTER the publish + let ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let node = ctx.create_node("tf_test_node_static").build().unwrap(); + let buffer = Buffer::new(&node).unwrap(); + + // TransientLocal should replay the stored message + tokio::time::sleep(Duration::from_millis(500)).await; + + assert!( + buffer.can_transform("map", "sensor", ZTime::zero()), + "static transform should be available after TransientLocal replay" + ); + let tf = buffer + .lookup_transform("map", "sensor", ZTime::zero()) + .unwrap(); + assert!((tf.transform.translation.x - 0.5).abs() < 1e-6); +} + +/// Two-frame chain: map→odom + odom→base_link, lookup map←base_link. +#[tokio::test(flavor = "multi_thread")] +async fn tf_two_frame_chain_composes_correctly() { + let router = TestRouter::new(); + let ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let node = ctx.create_node("tf_chain_node").build().unwrap(); + let buffer = Buffer::new(&node).unwrap(); + + let pub_ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let pub_node = pub_ctx.create_node("tf_chain_publisher").build().unwrap(); + let tf_pub = pub_node + .create_pub::("/tf") + .with_qos(QosProfile { + reliability: QosReliability::Reliable, + durability: QosDurability::Volatile, + history: QosHistory::KeepLast(std::num::NonZeroUsize::new(100).unwrap()), + ..Default::default() + }) + .build() + .unwrap(); + + tf_pub + .wait_for_subscription(1, Duration::from_secs(5)) + .await; + + // map→odom: x=1, odom→base_link: x=2 → base_link in map = x=3 + tf_pub + .async_publish(&TFMessage { + transforms: vec![ + make_tf("map", "odom", 10, 1.0), + make_tf("odom", "base_link", 10, 2.0), + ], + }) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(200)).await; + + let tf = buffer + .lookup_transform("map", "base_link", ZTime::zero()) + .unwrap(); + assert!( + (tf.transform.translation.x - 3.0).abs() < 1e-5, + "expected x=3.0, got {}", + tf.transform.translation.x + ); +} + +/// `can_transform` returns false before transforms arrive, true after. +#[tokio::test(flavor = "multi_thread")] +async fn can_transform_reflects_availability() { + let router = TestRouter::new(); + let ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let node = ctx.create_node("tf_can_transform_node").build().unwrap(); + let buffer = Buffer::new(&node).unwrap(); + + // Initially false + assert!(!buffer.can_transform("map", "robot", ZTime::zero())); + + let pub_ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let pub_node = pub_ctx.create_node("tf_can_publisher").build().unwrap(); + let tf_pub = pub_node + .create_pub::("/tf") + .with_qos(QosProfile { + reliability: QosReliability::Reliable, + durability: QosDurability::Volatile, + history: QosHistory::KeepLast(std::num::NonZeroUsize::new(100).unwrap()), + ..Default::default() + }) + .build() + .unwrap(); + + tf_pub + .wait_for_subscription(1, Duration::from_secs(5)) + .await; + tf_pub + .async_publish(&TFMessage { + transforms: vec![make_tf("map", "robot", 10, 1.0)], + }) + .await + .unwrap(); + + tokio::time::sleep(Duration::from_millis(200)).await; + + assert!(buffer.can_transform("map", "robot", ZTime::zero())); +} + +/// `TransformBroadcaster` publishes to `/tf`; `Buffer` on the same network receives it. +#[tokio::test(flavor = "multi_thread")] +async fn broadcaster_dynamic_roundtrip() { + let router = TestRouter::new(); + let ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let node = ctx.create_node("tf_broadcaster_rx").build().unwrap(); + let buffer = Buffer::new(&node).unwrap(); + + let pub_ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let pub_node = pub_ctx.create_node("tf_broadcaster_tx").build().unwrap(); + let broadcaster = TransformBroadcaster::new(&pub_node).unwrap(); + + // Give the subscription time to establish + tokio::time::sleep(Duration::from_millis(300)).await; + + broadcaster + .send_transform(make_tf("map", "base_link", 5, 2.5)) + .unwrap(); + + tokio::time::sleep(Duration::from_millis(300)).await; + + let tf = buffer + .lookup_transform("map", "base_link", ZTime::zero()) + .unwrap(); + assert!( + (tf.transform.translation.x - 2.5).abs() < 1e-6, + "expected x=2.5, got {}", + tf.transform.translation.x + ); +} + +/// `StaticTransformBroadcaster` publishes to `/tf_static` with TransientLocal; +/// a `Buffer` created after the publish receives the static transform via cache replay. +#[tokio::test(flavor = "multi_thread")] +async fn broadcaster_static_roundtrip_with_late_joiner() { + let router = TestRouter::new(); + + // Publish static transform BEFORE creating the Buffer + let pub_ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let pub_node = pub_ctx.create_node("tf_static_tx").build().unwrap(); + let broadcaster = StaticTransformBroadcaster::new(&pub_node).unwrap(); + + broadcaster + .send_transform(make_tf("world", "camera_link", 0, 0.3)) + .unwrap(); + + tokio::time::sleep(Duration::from_millis(300)).await; + + // Create Buffer after publication — should get replay + let ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let node = ctx.create_node("tf_static_rx").build().unwrap(); + let buffer = Buffer::new(&node).unwrap(); + + tokio::time::sleep(Duration::from_millis(500)).await; + + assert!( + buffer.can_transform("world", "camera_link", ZTime::zero()), + "static transform should be available via TransientLocal replay" + ); + let tf = buffer + .lookup_transform("world", "camera_link", ZTime::zero()) + .unwrap(); + assert!((tf.transform.translation.x - 0.3).abs() < 1e-6); +} + +/// `wait_for_transform` returns once a matching transform arrives. +#[tokio::test(flavor = "multi_thread")] +async fn wait_for_transform_returns_when_data_arrives() { + let router = TestRouter::new(); + let ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let node = ctx.create_node("wft_rx").build().unwrap(); + let buffer = Buffer::new(&node).unwrap(); + + let pub_ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let pub_node = pub_ctx.create_node("wft_tx").build().unwrap(); + let broadcaster = TransformBroadcaster::new(&pub_node).unwrap(); + + // Publish after a short delay in the background + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(300)).await; + broadcaster + .send_transform(make_tf("map", "lidar", 10, 1.0)) + .unwrap(); + }); + + let result = buffer + .wait_for_transform("map", "lidar", ZTime::zero(), Some(Duration::from_secs(3))) + .await; + + assert!(result.is_ok(), "expected transform, got: {:?}", result); + assert!((result.unwrap().transform.translation.x - 1.0).abs() < 1e-6); +} + +/// `wait_for_transform` returns `WaitError::Timeout` when no data arrives. +#[tokio::test(flavor = "multi_thread")] +async fn wait_for_transform_times_out() { + let router = TestRouter::new(); + let ctx = create_hiroz_context_with_endpoint(router.endpoint()).unwrap(); + let node = ctx.create_node("wft_timeout_node").build().unwrap(); + let buffer = Buffer::new(&node).unwrap(); + + let result = buffer + .wait_for_transform( + "ghost", + "frame", + ZTime::zero(), + Some(Duration::from_millis(300)), + ) + .await; + + assert!( + matches!(result, Err(WaitError::Timeout)), + "expected Timeout, got: {:?}", + result + ); +} diff --git a/crates/hiroz-tf/Cargo.toml b/crates/hiroz-tf/Cargo.toml new file mode 100644 index 00000000..fe85bc28 --- /dev/null +++ b/crates/hiroz-tf/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "hiroz-tf" +version.workspace = true +edition.workspace = true +description = "TF2 transform listener and buffer for hiroz" + +[dependencies] +hiroz = { path = "../hiroz", default-features = false, features = [ + "rmw-zenoh", +] } +hiroz-msgs = { path = "../hiroz-msgs", default-features = false, features = [ + "tf2_msgs", +] } +parking_lot = { workspace = true } +tokio = { workspace = true, features = ["sync", "time"] } +zenoh = { workspace = true } + +[features] +default = ["jazzy"] +jazzy = ["hiroz/jazzy", "hiroz-msgs/jazzy"] +humble = ["hiroz/humble", "hiroz-msgs/humble"] +rolling = ["hiroz/rolling", "hiroz-msgs/rolling"] +kilted = ["hiroz/kilted", "hiroz-msgs/kilted"] diff --git a/crates/hiroz-tf/README.md b/crates/hiroz-tf/README.md new file mode 100644 index 00000000..187ae747 --- /dev/null +++ b/crates/hiroz-tf/README.md @@ -0,0 +1,11 @@ + + +# hiroz-tf + +TF2 transform listener and buffer for hiroz. + +**📚 [Full Documentation](https://zettascalelabs.github.io/hiroz/chapters/tf.html)** diff --git a/crates/hiroz-tf/src/broadcaster.rs b/crates/hiroz-tf/src/broadcaster.rs new file mode 100644 index 00000000..e9f316bc --- /dev/null +++ b/crates/hiroz-tf/src/broadcaster.rs @@ -0,0 +1,157 @@ +use std::num::NonZeroUsize; + +use hiroz::Builder; +use hiroz::msg::NativeCdrSerdes; +use hiroz::node::ZNode; +use hiroz::pubsub::ZPub; +use hiroz::qos::{QosDurability, QosHistory, QosProfile, QosReliability}; +use hiroz_msgs::geometry_msgs::TransformStamped; +use hiroz_msgs::tf2_msgs::TFMessage; + +type TfPub = ZPub>; + +fn volatile_qos() -> QosProfile { + QosProfile { + reliability: QosReliability::Reliable, + durability: QosDurability::Volatile, + history: QosHistory::KeepLast(NonZeroUsize::new(100).unwrap()), + ..Default::default() + } +} + +fn transient_local_qos() -> QosProfile { + QosProfile { + reliability: QosReliability::Reliable, + durability: QosDurability::TransientLocal, + history: QosHistory::KeepLast(NonZeroUsize::new(100).unwrap()), + ..Default::default() + } +} + +/// Publishes dynamic transforms to `/tf` (Volatile durability). +/// +/// Use [`crate::Buffer`] on the same or another node to receive these transforms. +pub struct TransformBroadcaster { + pub_: TfPub, +} + +impl TransformBroadcaster { + /// Create a broadcaster attached to `node`. Declares a publisher on `/tf`. + pub fn new(node: &ZNode) -> zenoh::Result { + let pub_ = node + .create_pub::("/tf") + .with_qos(volatile_qos()) + .build()?; + Ok(Self { pub_ }) + } + + /// Publish a single transform to `/tf`. + pub fn send_transform(&self, tf: TransformStamped) -> zenoh::Result<()> { + self.send_transforms(vec![tf]) + } + + /// Publish multiple transforms to `/tf` in a single message. + pub fn send_transforms(&self, transforms: Vec) -> zenoh::Result<()> { + self.pub_.publish(&TFMessage { transforms }) + } +} + +/// Publishes static transforms to `/tf_static` (TransientLocal durability). +/// +/// Late-joining subscribers automatically receive all previously published +/// static transforms via `PublicationCache` replay. +/// +/// All timestamps are unconditionally set to `{sec: 0, nanosec: 0}` before +/// publishing, which is required by the tf2 standard for `/tf_static` messages +/// and ensures interoperability with ROS 2 tf2 clients and rviz2. +pub struct StaticTransformBroadcaster { + pub_: TfPub, +} + +impl StaticTransformBroadcaster { + /// Create a static broadcaster attached to `node`. Declares a publisher on `/tf_static`. + pub fn new(node: &ZNode) -> zenoh::Result { + let pub_ = node + .create_pub::("/tf_static") + .with_qos(transient_local_qos()) + .build()?; + Ok(Self { pub_ }) + } + + /// Publish a single static transform to `/tf_static`. + /// + /// The timestamp in `tf` is overwritten with zero before publishing. + pub fn send_transform(&self, tf: TransformStamped) -> zenoh::Result<()> { + self.send_transforms(vec![tf]) + } + + /// Publish multiple static transforms to `/tf_static` in a single message. + /// + /// All timestamps are overwritten with zero before publishing. + pub fn send_transforms(&self, transforms: Vec) -> zenoh::Result<()> { + self.pub_.publish(&TFMessage { + transforms: zero_timestamps(transforms), + }) + } +} + +/// Zero all timestamps in `transforms`, as required by the tf2 standard for `/tf_static`. +fn zero_timestamps(transforms: Vec) -> Vec { + transforms + .into_iter() + .map(|mut tf| { + tf.header.stamp = hiroz_msgs::builtin_interfaces::Time { sec: 0, nanosec: 0 }; + tf + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use hiroz_msgs::builtin_interfaces::Time; + use hiroz_msgs::geometry_msgs::{Quaternion, Transform, Vector3}; + use hiroz_msgs::std_msgs::Header; + + fn make_tf(parent: &str, child: &str, sec: i32) -> TransformStamped { + TransformStamped { + header: Header { + frame_id: parent.to_string(), + stamp: Time { sec, nanosec: 500 }, + }, + child_frame_id: child.to_string(), + transform: Transform { + translation: Vector3 { + x: 1.0, + y: 0.0, + z: 0.0, + }, + rotation: Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }, + }, + } + } + + #[test] + fn zero_timestamps_clears_all_stamps() { + let tfs = vec![make_tf("map", "odom", 10), make_tf("odom", "base_link", 20)]; + let zeroed = zero_timestamps(tfs); + for tf in &zeroed { + assert_eq!(tf.header.stamp.sec, 0); + assert_eq!(tf.header.stamp.nanosec, 0); + } + } + + #[test] + fn zero_timestamps_preserves_other_fields() { + let tfs = vec![make_tf("map", "sensor", 5)]; + let zeroed = zero_timestamps(tfs); + assert_eq!(zeroed[0].header.frame_id, "map"); + assert_eq!(zeroed[0].child_frame_id, "sensor"); + assert!((zeroed[0].transform.translation.x - 1.0).abs() < 1e-10); + } +} diff --git a/crates/hiroz-tf/src/buffer.rs b/crates/hiroz-tf/src/buffer.rs new file mode 100644 index 00000000..e94f6b53 --- /dev/null +++ b/crates/hiroz-tf/src/buffer.rs @@ -0,0 +1,272 @@ +use std::collections::{BTreeMap, HashMap}; +use std::sync::Arc; +use std::time::Duration; + +use hiroz::time::ZTime; +use hiroz_msgs::geometry_msgs::TransformStamped; +use hiroz_msgs::tf2_msgs::TFMessage; +use tokio::sync::Notify; + +/// Default maximum age of dynamic transforms to retain (10 seconds, matching tf2). +pub(crate) const DEFAULT_MAX_HISTORY: Duration = Duration::from_secs(10); + +/// Maximum depth when walking up the frame tree, to guard against cycles. +const MAX_TREE_DEPTH: usize = 100; + +pub(crate) struct BufferInner { + /// Dynamic transforms keyed by child_frame_id → time-sorted entries. + pub(crate) dynamic: HashMap>, + /// Static transforms keyed by child_frame_id → latest entry (no time history needed). + pub(crate) static_: HashMap, + pub(crate) max_history: Duration, + /// Notified on every `add_message` call so `wait_for_transform` can wake up. + pub(crate) notify: Arc, +} + +impl Default for BufferInner { + fn default() -> Self { + Self { + dynamic: HashMap::new(), + static_: HashMap::new(), + max_history: DEFAULT_MAX_HISTORY, + notify: Arc::new(Notify::new()), + } + } +} + +impl BufferInner { + pub(crate) fn add_message(&mut self, msg: TFMessage, is_static: bool) { + for tf in msg.transforms { + self.add_transform(tf, is_static); + } + self.notify.notify_waiters(); + } + + fn add_transform(&mut self, tf: TransformStamped, is_static: bool) { + if is_static { + self.static_.insert(tf.child_frame_id.clone(), tf); + } else { + let stamp = stamp_to_ztime(&tf.header.stamp); + let entries = self.dynamic.entry(tf.child_frame_id.clone()).or_default(); + entries.insert(stamp, tf); + self.prune_old_entries(stamp); + } + } + + fn prune_old_entries(&mut self, now: ZTime) { + let cutoff_nanos = now + .as_unix_nanos() + .saturating_sub(self.max_history.as_nanos() as i64); + let cutoff = ZTime::from_unix_nanos(cutoff_nanos); + + for entries in self.dynamic.values_mut() { + let old_keys: Vec = entries.range(..cutoff).map(|(k, _)| *k).collect(); + for k in old_keys { + entries.remove(&k); + } + } + } + + /// Return all known frame IDs (both dynamic and static children). + pub(crate) fn all_frames(&self) -> Vec { + let mut frames: std::collections::HashSet = std::collections::HashSet::new(); + for key in self.dynamic.keys() { + frames.insert(key.clone()); + } + for key in self.static_.keys() { + frames.insert(key.clone()); + } + frames.into_iter().collect() + } + + /// Walk from `frame` toward the tree root, returning the path + /// `[frame, parent(frame), parent(parent(frame)), ..., root]`. + pub(crate) fn path_to_root(&self, frame: &str) -> Vec { + let mut path = vec![frame.to_string()]; + let mut current = frame.to_string(); + + while path.len() < MAX_TREE_DEPTH { + let parent = self + .static_ + .get(¤t) + .map(|tf| tf.header.frame_id.clone()) + .or_else(|| { + self.dynamic + .get(¤t) + .and_then(|entries| entries.values().next_back()) + .map(|tf| tf.header.frame_id.clone()) + }); + + match parent { + None => break, + Some(p) if path.contains(&p) => break, // cycle guard + Some(p) => { + path.push(p.clone()); + current = p; + } + } + } + + path + } +} + +/// Convert a `builtin_interfaces::Time` stamp to `ZTime`. +pub(crate) fn stamp_to_ztime(stamp: &hiroz_msgs::builtin_interfaces::Time) -> ZTime { + let total_nanos = (stamp.sec as i64) + .saturating_mul(1_000_000_000) + .saturating_add(stamp.nanosec as i64); + ZTime::from_unix_nanos(total_nanos) +} + +#[cfg(test)] +mod tests { + use super::*; + use hiroz_msgs::builtin_interfaces::Time; + use hiroz_msgs::geometry_msgs::{Quaternion, Transform, Vector3}; + use hiroz_msgs::std_msgs::Header; + + fn make_tf(parent: &str, child: &str, sec: i32) -> TransformStamped { + TransformStamped { + header: Header { + frame_id: parent.to_string(), + stamp: Time { sec, nanosec: 0 }, + }, + child_frame_id: child.to_string(), + transform: Transform { + translation: Vector3 { + x: 1.0, + y: 0.0, + z: 0.0, + }, + rotation: Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }, + }, + } + } + + #[test] + fn add_dynamic_transform_inserts_entry() { + let mut buf = BufferInner::default(); + let tf = make_tf("map", "odom", 100); + buf.add_message( + TFMessage { + transforms: vec![tf], + }, + false, + ); + assert!(buf.dynamic.contains_key("odom")); + } + + #[test] + fn add_static_transform_overwrites_previous() { + let mut buf = BufferInner::default(); + let tf1 = make_tf("map", "sensor", 1); + let tf2 = make_tf("world", "sensor", 2); + buf.add_message( + TFMessage { + transforms: vec![tf1], + }, + true, + ); + buf.add_message( + TFMessage { + transforms: vec![tf2], + }, + true, + ); + assert_eq!(buf.static_["sensor"].header.frame_id, "world"); + } + + #[test] + fn prune_removes_old_entries() { + let mut buf = BufferInner { + max_history: Duration::from_secs(5), + ..Default::default() + }; + // Insert old entry at t=0 + buf.add_message( + TFMessage { + transforms: vec![make_tf("map", "odom", 0)], + }, + false, + ); + // Insert recent entry at t=100 — triggers pruning of t=0 (100-0 > 5s) + buf.add_message( + TFMessage { + transforms: vec![make_tf("map", "odom", 100)], + }, + false, + ); + let entries = &buf.dynamic["odom"]; + let oldest_sec = entries + .values() + .next() + .map(|tf| tf.header.stamp.sec) + .unwrap(); + assert!(oldest_sec > 0, "old entry at t=0 should have been pruned"); + } + + #[test] + fn path_to_root_single_hop() { + let mut buf = BufferInner::default(); + buf.add_message( + TFMessage { + transforms: vec![make_tf("map", "odom", 1)], + }, + false, + ); + let path = buf.path_to_root("odom"); + assert_eq!(path, vec!["odom", "map"]); + } + + #[test] + fn path_to_root_two_hops() { + let mut buf = BufferInner::default(); + buf.add_message( + TFMessage { + transforms: vec![make_tf("map", "odom", 1)], + }, + false, + ); + buf.add_message( + TFMessage { + transforms: vec![make_tf("odom", "base_link", 1)], + }, + false, + ); + let path = buf.path_to_root("base_link"); + assert_eq!(path, vec!["base_link", "odom", "map"]); + } + + #[test] + fn path_to_root_unknown_frame_is_just_itself() { + let buf = BufferInner::default(); + let path = buf.path_to_root("unknown"); + assert_eq!(path, vec!["unknown"]); + } + + #[test] + fn all_frames_includes_both_static_and_dynamic() { + let mut buf = BufferInner::default(); + buf.add_message( + TFMessage { + transforms: vec![make_tf("map", "odom", 1)], + }, + false, + ); + buf.add_message( + TFMessage { + transforms: vec![make_tf("map", "sensor", 1)], + }, + true, + ); + let frames = buf.all_frames(); + assert!(frames.contains(&"odom".to_string())); + assert!(frames.contains(&"sensor".to_string())); + } +} diff --git a/crates/hiroz-tf/src/lib.rs b/crates/hiroz-tf/src/lib.rs new file mode 100644 index 00000000..b2cbeebf --- /dev/null +++ b/crates/hiroz-tf/src/lib.rs @@ -0,0 +1,283 @@ +//! TF2 transform listener and buffer for hiroz. +//! +//! Subscribes to `/tf` (dynamic) and `/tf_static` (TransientLocal) and provides +//! `lookup_transform` with multi-hop LCA traversal and linear/slerp interpolation. +//! +//! # Quick start +//! +//! ```rust,ignore +//! use hiroz::prelude::*; +//! use hiroz_tf::Buffer; +//! +//! #[tokio::main] +//! async fn main() -> zenoh::Result<()> { +//! let ctx = ZContextBuilder::default() +//! .with_connect_endpoints(["tcp/127.0.0.1:7447"]) +//! .build()?; +//! let node = ctx.create_node("tf_listener").build()?; +//! let buffer = Buffer::new(&node)?; +//! +//! tokio::time::sleep(std::time::Duration::from_millis(500)).await; +//! +//! match buffer.lookup_transform("map", "base_link", ZTime::zero()) { +//! Ok(tf) => println!("x={}", tf.transform.translation.x), +//! Err(e) => eprintln!("lookup failed: {e}"), +//! } +//! Ok(()) +//! } +//! ``` + +use std::fmt; +use std::num::NonZeroUsize; +use std::sync::Arc; +use std::time::Duration; + +use hiroz::msg::NativeCdrSerdes; +use hiroz::node::ZNode; +use hiroz::pubsub::ZSub; +use hiroz::qos::{QosDurability, QosHistory, QosProfile, QosReliability}; +use hiroz::time::ZTime; +use hiroz_msgs::geometry_msgs::TransformStamped; +use hiroz_msgs::tf2_msgs::TFMessage; +use parking_lot::RwLock; +use tokio::sync::Notify; + +mod broadcaster; +mod buffer; +mod lookup; +mod math; + +pub use broadcaster::{StaticTransformBroadcaster, TransformBroadcaster}; + +use buffer::{BufferInner, DEFAULT_MAX_HISTORY}; + +type TfSub = ZSub>; + +/// Error returned by [`Buffer::lookup_transform`]. +#[derive(Debug)] +pub enum LookupError { + /// The requested frame has no known transforms. + UnknownFrame(String), + /// `source` and `target` are in disconnected sub-trees. + NoCommonAncestor { source: String, target: String }, + /// The requested timestamp is outside the stored history window. + ExtrapolationError { + frame: String, + requested: ZTime, + oldest: ZTime, + newest: ZTime, + }, +} + +impl fmt::Display for LookupError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LookupError::UnknownFrame(frame) => { + write!(f, "frame '{frame}' has no known transforms") + } + LookupError::NoCommonAncestor { source, target } => { + write!(f, "no common ancestor between '{source}' and '{target}'") + } + LookupError::ExtrapolationError { + frame, + requested, + oldest, + newest, + } => { + write!( + f, + "requested time {:?} for frame '{}' is outside buffer window [{:?}, {:?}]", + requested, frame, oldest, newest + ) + } + } + } +} + +impl std::error::Error for LookupError {} + +/// Error returned by [`Buffer::wait_for_transform`]. +#[derive(Debug)] +pub enum WaitError { + /// The timeout elapsed before the transform became available. + Timeout, + /// The lookup failed with an error that will not resolve with more time + /// (e.g., disconnected frame trees). + Lookup(LookupError), +} + +impl fmt::Display for WaitError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + WaitError::Timeout => write!(f, "wait_for_transform timed out"), + WaitError::Lookup(e) => write!(f, "lookup failed permanently: {e}"), + } + } +} + +impl std::error::Error for WaitError {} + +/// TF2 transform buffer and listener. +/// +/// Subscribes to `/tf` and `/tf_static` on the provided node and maintains +/// an in-memory frame tree. Drop this value to cancel the subscriptions. +/// +/// # Separation from broadcasters +/// +/// `Buffer` only *receives* transforms. To publish transforms use +/// [`TransformBroadcaster`] (dynamic) or [`StaticTransformBroadcaster`] (static). +/// The two types are intentionally separate so listener-only nodes have no +/// publishing overhead, and publisher-only nodes have no subscriber overhead. +/// +/// A node that needs both can create a `Buffer` and a broadcaster on the same +/// [`ZNode`]. Transforms published via the broadcaster are relayed through the +/// Zenoh router, so they will be received by any `Buffer` on the same network, +/// including the one on the same node. +/// +/// Create with [`Buffer::new`]. +pub struct Buffer { + inner: Arc>, + notify: Arc, + buffer_duration: Duration, + _tf_sub: TfSub, + _tf_static_sub: TfSub, +} + +impl Buffer { + /// Subscribe to `/tf` and `/tf_static` on `node` and return a new buffer. + pub fn new(node: &ZNode) -> zenoh::Result { + let inner = Arc::new(RwLock::new(BufferInner::default())); + + let dynamic_qos = QosProfile { + reliability: QosReliability::Reliable, + durability: QosDurability::Volatile, + history: QosHistory::KeepLast(NonZeroUsize::new(100).unwrap()), + ..Default::default() + }; + + let static_qos = QosProfile { + reliability: QosReliability::Reliable, + durability: QosDurability::TransientLocal, + history: QosHistory::KeepLast(NonZeroUsize::new(100).unwrap()), + ..Default::default() + }; + + let inner_dyn = Arc::clone(&inner); + let tf_sub = node + .create_sub::("/tf") + .with_qos(dynamic_qos) + .build_with_callback(move |msg: TFMessage| { + inner_dyn.write().add_message(msg, false); + })?; + + let inner_static = Arc::clone(&inner); + let tf_static_sub = node + .create_sub::("/tf_static") + .with_qos(static_qos) + .build_with_callback(move |msg: TFMessage| { + inner_static.write().add_message(msg, true); + })?; + + let notify = Arc::clone(&inner.read().notify); + Ok(Buffer { + inner, + notify, + buffer_duration: DEFAULT_MAX_HISTORY, + _tf_sub: tf_sub, + _tf_static_sub: tf_static_sub, + }) + } + + /// Look up the transform from `source` frame to `target` frame at `time`. + /// + /// Pass [`ZTime::zero()`] to request the latest available transform. + /// + /// The returned `TransformStamped` maps a point expressed in `source` + /// coordinates into `target` coordinates. + pub fn lookup_transform( + &self, + target: &str, + source: &str, + time: ZTime, + ) -> Result { + self.inner.read().lookup(target, source, time) + } + + /// Return `true` if [`lookup_transform`](Self::lookup_transform) would + /// succeed for the given frames and time. + pub fn can_transform(&self, target: &str, source: &str, time: ZTime) -> bool { + self.inner.read().lookup(target, source, time).is_ok() + } + + /// Return all frame IDs currently known to the buffer. + pub fn all_frames(&self) -> Vec { + self.inner.read().all_frames() + } + + /// Look up the transform from `source` at `source_time` to `target` at + /// `target_time`, routing through `fixed_frame`. + /// + /// Matches the tf2 C++ signature: + /// `lookupTransform(target, target_time, source, source_time, fixed_frame)`. + /// + /// Equivalent to: + /// ```text + /// T(target ← source) = T(target ← fixed_frame, target_time) + /// ∘ T(fixed_frame ← source, source_time) + /// ``` + /// + /// Used when target and source are observed at different times and need to be + /// related through a fixed reference frame (typically `"map"`). + pub fn lookup_transform_full( + &self, + target: &str, + target_time: ZTime, + source: &str, + source_time: ZTime, + fixed_frame: &str, + ) -> Result { + let inner = self.inner.read(); + let t1 = inner.lookup(fixed_frame, source, source_time)?; + let t2 = inner.lookup(target, fixed_frame, target_time)?; + Ok(crate::math::compose_stamped(t2, t1, target, source)) + } + + /// Wait asynchronously until `lookup_transform` succeeds or `timeout` elapses. + /// + /// Pass `None` to use the buffer's default duration (10 seconds). + /// + /// Returns `Err(WaitError::Timeout)` if no transform arrives within the + /// deadline. Returns `Err(WaitError::Lookup(...))` immediately if the + /// frames are in disconnected trees (waiting cannot resolve the error). + pub async fn wait_for_transform( + &self, + target: &str, + source: &str, + time: ZTime, + timeout: Option, + ) -> Result { + let timeout = timeout.unwrap_or(self.buffer_duration); + let deadline = tokio::time::Instant::now() + timeout; + loop { + match self.inner.read().lookup(target, source, time) { + Ok(tf) => return Ok(tf), + Err(LookupError::NoCommonAncestor { + source: s, + target: t, + }) => { + return Err(WaitError::Lookup(LookupError::NoCommonAncestor { + source: s, + target: t, + })); + } + Err(_) => {} + } + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + return Err(WaitError::Timeout); + } + // Wait for new data or until the deadline, whichever comes first. + let _ = tokio::time::timeout(remaining, self.notify.notified()).await; + } + } +} diff --git a/crates/hiroz-tf/src/lookup.rs b/crates/hiroz-tf/src/lookup.rs new file mode 100644 index 00000000..46144296 --- /dev/null +++ b/crates/hiroz-tf/src/lookup.rs @@ -0,0 +1,389 @@ +use hiroz::time::ZTime; +use hiroz_msgs::geometry_msgs::{Transform, TransformStamped}; + +use crate::LookupError; +use crate::buffer::BufferInner; +use crate::math; + +impl BufferInner { + /// Perform the full transform lookup from `source` to `target` at `time`. + /// `ZTime::zero()` means "latest available". + pub(crate) fn lookup( + &self, + target: &str, + source: &str, + time: ZTime, + ) -> Result { + // Trivial case: same frame + if target == source { + return Ok(identity_stamped(target)); + } + + let source_path = self.path_to_root(source); + let target_path = self.path_to_root(target); + + // Verify both frames exist somewhere in the known transform tree + if !self.frame_exists_anywhere(source) { + return Err(LookupError::UnknownFrame(source.to_string())); + } + if !self.frame_exists_anywhere(target) { + return Err(LookupError::UnknownFrame(target.to_string())); + } + + // Find the lowest common ancestor + let lca_idx_in_source = source_path + .iter() + .position(|f| target_path.contains(f)) + .ok_or_else(|| LookupError::NoCommonAncestor { + source: source.to_string(), + target: target.to_string(), + })?; + + let lca = &source_path[lca_idx_in_source]; + let lca_idx_in_target = target_path.iter().position(|f| f == lca).unwrap(); + + // Build T_{LCA←source} by composing edges source→p1→...→LCA + let source_to_lca = &source_path[..=lca_idx_in_source]; + let mut t_lca_from_source = math::identity_transform(); + for edge in source_to_lca.windows(2) { + let child = &edge[0]; + let edge_tf = self.interpolate_edge(child, time)?; + t_lca_from_source = math::compose_transforms(&t_lca_from_source, &edge_tf); + } + + // Build T_{LCA←target} by composing edges target→q1→...→LCA + let target_to_lca = &target_path[..=lca_idx_in_target]; + let mut t_lca_from_target = math::identity_transform(); + for edge in target_to_lca.windows(2) { + let child = &edge[0]; + let edge_tf = self.interpolate_edge(child, time)?; + t_lca_from_target = math::compose_transforms(&t_lca_from_target, &edge_tf); + } + + // T_{target←source} = T_{target←LCA} * T_{LCA←source} + // = inv(T_{LCA←target}) ∘ T_{LCA←source} + let t_target_from_lca = math::invert_transform(&t_lca_from_target); + let result = math::compose_transforms(&t_lca_from_source, &t_target_from_lca); + + Ok(TransformStamped { + header: hiroz_msgs::std_msgs::Header { + frame_id: target.to_string(), + stamp: ztime_to_stamp(time), + }, + child_frame_id: source.to_string(), + transform: result, + }) + } + + /// Return true iff `frame` appears anywhere in the stored transform tree + /// (as a child frame OR as a parent/header frame of some stored transform). + pub(crate) fn frame_exists_anywhere(&self, frame: &str) -> bool { + if self.dynamic.contains_key(frame) || self.static_.contains_key(frame) { + return true; + } + // Check if it appears as a parent in any stored transform + self.dynamic.values().any(|entries| { + entries + .values() + .next_back() + .is_some_and(|tf| tf.header.frame_id == frame) + }) || self.static_.values().any(|tf| tf.header.frame_id == frame) + } + + /// Interpolate the stored transform for `child` at `time`. + /// + /// Checks static first (always valid, no time interpolation). + /// Falls back to dynamic with bracketed linear/slerp interpolation. + pub(crate) fn interpolate_edge( + &self, + child: &str, + time: ZTime, + ) -> Result { + // Static transforms are always valid + if let Some(tf) = self.static_.get(child) { + return Ok(tf.transform.clone()); + } + + let entries = self + .dynamic + .get(child) + .ok_or_else(|| LookupError::UnknownFrame(child.to_string()))?; + + if entries.is_empty() { + return Err(LookupError::UnknownFrame(child.to_string())); + } + + // Latest-available sentinel + if time == ZTime::zero() { + let (_, tf) = entries.iter().next_back().unwrap(); + return Ok(tf.transform.clone()); + } + + let oldest = *entries.keys().next().unwrap(); + let newest = *entries.keys().next_back().unwrap(); + + if time < oldest || time > newest { + return Err(LookupError::ExtrapolationError { + frame: child.to_string(), + requested: time, + oldest, + newest, + }); + } + + // Exact match + if let Some(tf) = entries.get(&time) { + return Ok(tf.transform.clone()); + } + + // Interpolate between surrounding entries + let before = entries + .range(..time) + .next_back() + .map(|(t, tf)| (*t, tf.transform.clone())); + let after = entries + .range(time..) + .next() + .map(|(t, tf)| (*t, tf.transform.clone())); + + match (before, after) { + (Some((t0, tf0)), Some((t1, tf1))) => { + let t0_ns = t0.as_unix_nanos(); + let t1_ns = t1.as_unix_nanos(); + let req_ns = time.as_unix_nanos(); + let alpha = (req_ns - t0_ns) as f64 / (t1_ns - t0_ns) as f64; + Ok(math::interpolate_transforms(&tf0, &tf1, alpha)) + } + _ => Err(LookupError::ExtrapolationError { + frame: child.to_string(), + requested: time, + oldest, + newest, + }), + } + } +} + +fn identity_stamped(frame: &str) -> TransformStamped { + TransformStamped { + header: hiroz_msgs::std_msgs::Header { + frame_id: frame.to_string(), + stamp: hiroz_msgs::builtin_interfaces::Time { sec: 0, nanosec: 0 }, + }, + child_frame_id: frame.to_string(), + transform: math::identity_transform(), + } +} + +fn ztime_to_stamp(t: ZTime) -> hiroz_msgs::builtin_interfaces::Time { + let nanos = t.as_unix_nanos().max(0) as u64; + hiroz_msgs::builtin_interfaces::Time { + sec: (nanos / 1_000_000_000) as i32, + nanosec: (nanos % 1_000_000_000) as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::buffer::BufferInner; + use hiroz_msgs::builtin_interfaces::Time; + use hiroz_msgs::geometry_msgs::{Quaternion, Transform, Vector3}; + use hiroz_msgs::std_msgs::Header; + use hiroz_msgs::tf2_msgs::TFMessage; + + fn make_tf_at(parent: &str, child: &str, sec: i32, x: f64) -> TransformStamped { + TransformStamped { + header: Header { + frame_id: parent.to_string(), + stamp: Time { sec, nanosec: 0 }, + }, + child_frame_id: child.to_string(), + transform: Transform { + translation: Vector3 { x, y: 0.0, z: 0.0 }, + rotation: Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }, + }, + } + } + + fn t(sec: i32) -> ZTime { + ZTime::from_unix_nanos(sec as i64 * 1_000_000_000) + } + + #[test] + fn lookup_unknown_frame_errors() { + let buf = BufferInner::default(); + assert!(matches!( + buf.lookup("map", "base_link", ZTime::zero()), + Err(LookupError::UnknownFrame(_)) + )); + } + + #[test] + fn lookup_same_frame_is_identity() { + let buf = BufferInner::default(); + let tf = buf.lookup("map", "map", ZTime::zero()).unwrap(); + assert!((tf.transform.translation.x).abs() < 1e-10); + assert!((tf.transform.rotation.w - 1.0).abs() < 1e-10); + } + + #[test] + fn lookup_direct_edge_returns_latest() { + let mut buf = BufferInner::default(); + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("map", "odom", 10, 5.0)], + }, + false, + ); + let tf = buf.lookup("map", "odom", ZTime::zero()).unwrap(); + assert!((tf.transform.translation.x - 5.0).abs() < 1e-10); + } + + #[test] + fn lookup_interpolates_at_midpoint() { + let mut buf = BufferInner::default(); + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("map", "odom", 10, 0.0)], + }, + false, + ); + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("map", "odom", 20, 10.0)], + }, + false, + ); + // At t=15 (midpoint), x should be ~5.0 + let tf = buf.lookup("map", "odom", t(15)).unwrap(); + assert!((tf.transform.translation.x - 5.0).abs() < 1e-6); + } + + #[test] + fn lookup_extrapolation_error_outside_window() { + let mut buf = BufferInner::default(); + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("map", "odom", 10, 0.0)], + }, + false, + ); + assert!(matches!( + buf.lookup("map", "odom", t(5)), + Err(LookupError::ExtrapolationError { .. }) + )); + } + + #[test] + fn lookup_two_hop_chain() { + let mut buf = BufferInner::default(); + // map→odom: translate x=1 + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("map", "odom", 10, 1.0)], + }, + false, + ); + // odom→base_link: translate x=2 + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("odom", "base_link", 10, 2.0)], + }, + false, + ); + // Expected: base_link in map = x=3 + let tf = buf.lookup("map", "base_link", ZTime::zero()).unwrap(); + assert!((tf.transform.translation.x - 3.0).abs() < 1e-6); + } + + #[test] + fn lookup_static_transform_at_any_time() { + let mut buf = BufferInner::default(); + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("map", "sensor", 0, 0.5)], + }, + true, + ); + // Static transforms should be returned regardless of requested time + let tf = buf.lookup("map", "sensor", t(9999)).unwrap(); + assert!((tf.transform.translation.x - 0.5).abs() < 1e-10); + } + + #[test] + fn lookup_no_common_ancestor_errors() { + let mut buf = BufferInner::default(); + // Two disconnected trees + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("world_a", "odom_a", 10, 1.0)], + }, + false, + ); + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("world_b", "odom_b", 10, 1.0)], + }, + false, + ); + assert!(matches!( + buf.lookup("odom_a", "odom_b", ZTime::zero()), + Err(LookupError::NoCommonAncestor { .. }) + )); + } + + #[test] + fn lookup_inverse_direction() { + let mut buf = BufferInner::default(); + // map→odom: translate x=5 + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("map", "odom", 10, 5.0)], + }, + false, + ); + // Lookup odom←map should be inverse: x=-5 + let tf = buf.lookup("odom", "map", ZTime::zero()).unwrap(); + assert!((tf.transform.translation.x - (-5.0)).abs() < 1e-6); + } + + #[test] + fn lookup_full_via_fixed_frame() { + // Set up: map→odom (x=1) and map→camera (x=3). + // lookup_full("camera", "odom", t, "map", t) should give x=2 (camera relative to odom). + let mut buf = BufferInner::default(); + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("map", "odom", 10, 1.0)], + }, + false, + ); + buf.add_message( + TFMessage { + transforms: vec![make_tf_at("map", "camera", 10, 3.0)], + }, + false, + ); + + let t = ZTime::zero(); + // T(map ← odom) at t, then T(camera ← map) at t + let t1 = buf.lookup("map", "odom", t).unwrap(); + let t2 = buf.lookup("camera", "map", t).unwrap(); + let result = crate::math::compose_stamped(t2, t1, "camera", "odom"); + // camera is at x=3, odom is at x=1 in map frame. + // odom expressed in camera frame = 1 - 3 = -2 (odom is behind camera). + assert!( + (result.transform.translation.x - (-2.0)).abs() < 1e-5, + "expected x=-2.0, got {}", + result.transform.translation.x + ); + assert_eq!(result.header.frame_id, "camera"); + assert_eq!(result.child_frame_id, "odom"); + } +} diff --git a/crates/hiroz-tf/src/math.rs b/crates/hiroz-tf/src/math.rs new file mode 100644 index 00000000..c968829b --- /dev/null +++ b/crates/hiroz-tf/src/math.rs @@ -0,0 +1,406 @@ +use hiroz_msgs::geometry_msgs::{Quaternion, Transform, TransformStamped, Vector3}; +use hiroz_msgs::std_msgs::Header; + +pub fn identity_transform() -> Transform { + Transform { + translation: Vector3 { + x: 0.0, + y: 0.0, + z: 0.0, + }, + rotation: Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }, + } +} + +pub fn quaternion_dot(a: &Quaternion, b: &Quaternion) -> f64 { + a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w +} + +pub fn quaternion_negate(q: &Quaternion) -> Quaternion { + Quaternion { + x: -q.x, + y: -q.y, + z: -q.z, + w: -q.w, + } +} + +pub fn quaternion_normalize(q: &Quaternion) -> Quaternion { + let norm = (q.x * q.x + q.y * q.y + q.z * q.z + q.w * q.w).sqrt(); + if norm < f64::EPSILON { + return Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }; + } + Quaternion { + x: q.x / norm, + y: q.y / norm, + z: q.z / norm, + w: q.w / norm, + } +} + +pub fn quaternion_conjugate(q: &Quaternion) -> Quaternion { + Quaternion { + x: -q.x, + y: -q.y, + z: -q.z, + w: q.w, + } +} + +/// Quaternion product: lhs * rhs. +/// Applying rhs rotation first, then lhs. +pub fn quaternion_multiply(lhs: &Quaternion, rhs: &Quaternion) -> Quaternion { + Quaternion { + x: lhs.w * rhs.x + lhs.x * rhs.w + lhs.y * rhs.z - lhs.z * rhs.y, + y: lhs.w * rhs.y - lhs.x * rhs.z + lhs.y * rhs.w + lhs.z * rhs.x, + z: lhs.w * rhs.z + lhs.x * rhs.y - lhs.y * rhs.x + lhs.z * rhs.w, + w: lhs.w * rhs.w - lhs.x * rhs.x - lhs.y * rhs.y - lhs.z * rhs.z, + } +} + +/// Rotate vector v by quaternion q: q * v * q_conjugate. +pub fn rotate_vector(q: &Quaternion, v: &Vector3) -> Vector3 { + let tx = 2.0 * (q.y * v.z - q.z * v.y); + let ty = 2.0 * (q.z * v.x - q.x * v.z); + let tz = 2.0 * (q.x * v.y - q.y * v.x); + Vector3 { + x: v.x + q.w * tx + q.y * tz - q.z * ty, + y: v.y + q.w * ty + q.z * tx - q.x * tz, + z: v.z + q.w * tz + q.x * ty - q.y * tx, + } +} + +/// Spherical linear interpolation between two quaternions. +/// t=0 returns q0, t=1 returns q1. +pub fn slerp(q0: &Quaternion, q1: &Quaternion, t: f64) -> Quaternion { + let mut dot = quaternion_dot(q0, q1); + // Ensure shortest path + let q1_eff = if dot < 0.0 { + dot = -dot; + quaternion_negate(q1) + } else { + *q1 + }; + // For very close quaternions use normalised linear interpolation + if dot > 0.9995 { + let interp = Quaternion { + x: q0.x + t * (q1_eff.x - q0.x), + y: q0.y + t * (q1_eff.y - q0.y), + z: q0.z + t * (q1_eff.z - q0.z), + w: q0.w + t * (q1_eff.w - q0.w), + }; + return quaternion_normalize(&interp); + } + let theta0 = dot.acos(); + let sin_theta0 = theta0.sin(); + let s0 = ((1.0 - t) * theta0).sin() / sin_theta0; + let s1 = (t * theta0).sin() / sin_theta0; + Quaternion { + x: s0 * q0.x + s1 * q1_eff.x, + y: s0 * q0.y + s1 * q1_eff.y, + z: s0 * q0.z + s1 * q1_eff.z, + w: s0 * q0.w + s1 * q1_eff.w, + } +} + +/// Linear interpolation of translations. +pub fn lerp_vector3(v0: &Vector3, v1: &Vector3, t: f64) -> Vector3 { + Vector3 { + x: v0.x + t * (v1.x - v0.x), + y: v0.y + t * (v1.y - v0.y), + z: v0.z + t * (v1.z - v0.z), + } +} + +/// Interpolate between two transforms at factor t (0=a, 1=b). +pub fn interpolate_transforms(a: &Transform, b: &Transform, t: f64) -> Transform { + Transform { + translation: lerp_vector3(&a.translation, &b.translation, t), + rotation: slerp(&a.rotation, &b.rotation, t), + } +} + +/// compose_transforms(a, b): apply a first, then b. +/// Equivalent to: point_out = b.rotation * (a.rotation * point + a.translation) + b.translation +pub fn compose_transforms(a: &Transform, b: &Transform) -> Transform { + Transform { + translation: Vector3 { + x: rotate_vector(&b.rotation, &a.translation).x + b.translation.x, + y: rotate_vector(&b.rotation, &a.translation).y + b.translation.y, + z: rotate_vector(&b.rotation, &a.translation).z + b.translation.z, + }, + rotation: quaternion_multiply(&b.rotation, &a.rotation), + } +} + +/// Compose two `TransformStamped` values and wrap the result with new frame labels. +/// +/// The result represents `T(target_frame ← source_frame)`: +/// `T(target ← fixed) ∘ T(fixed ← source)`. +/// The result's stamp is taken from `t2` (the target side). +pub fn compose_stamped( + t2: TransformStamped, + t1: TransformStamped, + target_frame: &str, + source_frame: &str, +) -> TransformStamped { + TransformStamped { + header: Header { + frame_id: target_frame.to_string(), + stamp: t2.header.stamp, + }, + child_frame_id: source_frame.to_string(), + transform: compose_transforms(&t1.transform, &t2.transform), + } +} + +/// Invert a transform: T^-1 such that compose(T, T^-1) = identity. +pub fn invert_transform(t: &Transform) -> Transform { + let rot_inv = quaternion_conjugate(&t.rotation); + let trans_inv = rotate_vector( + &rot_inv, + &Vector3 { + x: -t.translation.x, + y: -t.translation.y, + z: -t.translation.z, + }, + ); + Transform { + translation: trans_inv, + rotation: rot_inv, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn approx_eq(a: f64, b: f64) -> bool { + (a - b).abs() < 1e-10 + } + + fn quat_approx_eq(a: &Quaternion, b: &Quaternion) -> bool { + // q and -q represent the same rotation + let same = approx_eq(a.x, b.x) + && approx_eq(a.y, b.y) + && approx_eq(a.z, b.z) + && approx_eq(a.w, b.w); + let neg = approx_eq(a.x, -b.x) + && approx_eq(a.y, -b.y) + && approx_eq(a.z, -b.z) + && approx_eq(a.w, -b.w); + same || neg + } + + fn vec_approx_eq(a: &Vector3, b: &Vector3) -> bool { + approx_eq(a.x, b.x) && approx_eq(a.y, b.y) && approx_eq(a.z, b.z) + } + + fn identity_quat() -> Quaternion { + Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + } + } + + // Unit quaternion for 90° rotation around Z: (0, 0, sin(45°), cos(45°)) + fn q_90z() -> Quaternion { + let h = std::f64::consts::FRAC_PI_4; // π/4 = half-angle for 90° rotation + Quaternion { + x: 0.0, + y: 0.0, + z: h.sin(), + w: h.cos(), + } + } + + #[test] + fn slerp_at_t0_returns_q0() { + let q0 = Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }; + let r = slerp(&q0, &q_90z(), 0.0); + assert!(quat_approx_eq(&r, &q0)); + } + + #[test] + fn slerp_at_t1_returns_q1() { + let q1 = q_90z(); + let q0 = Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }; + let r = slerp(&q0, &q1, 1.0); + assert!(quat_approx_eq(&r, &q1)); + } + + #[test] + fn slerp_at_midpoint_is_normalized() { + let q0 = Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }; + let r = slerp(&q0, &q_90z(), 0.5); + let norm = (r.x * r.x + r.y * r.y + r.z * r.z + r.w * r.w).sqrt(); + assert!(approx_eq(norm, 1.0)); + } + + #[test] + fn compose_with_identity_is_noop() { + let id = identity_transform(); + let t = Transform { + translation: Vector3 { + x: 1.0, + y: 2.0, + z: 3.0, + }, + rotation: Quaternion { + x: 0.0, + y: 0.0, + z: 0.5_f64.sqrt(), + w: 0.5_f64.sqrt(), + }, + }; + let r = compose_transforms(&t, &id); + assert!(vec_approx_eq(&r.translation, &t.translation)); + assert!(quat_approx_eq(&r.rotation, &t.rotation)); + } + + #[test] + fn compose_identity_with_t_is_t() { + let id = identity_transform(); + let t = Transform { + translation: Vector3 { + x: 1.0, + y: 2.0, + z: 3.0, + }, + rotation: Quaternion { + x: 0.0, + y: 0.0, + z: 0.5_f64.sqrt(), + w: 0.5_f64.sqrt(), + }, + }; + let r = compose_transforms(&id, &t); + assert!(vec_approx_eq(&r.translation, &t.translation)); + assert!(quat_approx_eq(&r.rotation, &t.rotation)); + } + + #[test] + fn compose_then_invert_is_identity() { + let t = Transform { + translation: Vector3 { + x: 1.0, + y: 2.0, + z: 3.0, + }, + rotation: Quaternion { + x: 0.0, + y: 0.0, + z: 0.5_f64.sqrt(), + w: 0.5_f64.sqrt(), + }, + }; + let result = compose_transforms(&t, &invert_transform(&t)); + let id = identity_transform(); + assert!(vec_approx_eq(&result.translation, &id.translation)); + assert!(quat_approx_eq(&result.rotation, &id.rotation)); + } + + #[test] + fn rotate_vector_with_identity_is_noop() { + let q = identity_quat(); + let v = Vector3 { + x: 1.0, + y: 2.0, + z: 3.0, + }; + let r = rotate_vector(&q, &v); + assert!(vec_approx_eq(&r, &v)); + } + + #[test] + fn rotate_vector_90_degrees_around_z() { + // 90° rotation around Z: x→y, y→-x + let angle = std::f64::consts::PI / 2.0; + let q = Quaternion { + x: 0.0, + y: 0.0, + z: (angle / 2.0).sin(), + w: (angle / 2.0).cos(), + }; + let v = Vector3 { + x: 1.0, + y: 0.0, + z: 0.0, + }; + let r = rotate_vector(&q, &v); + assert!(vec_approx_eq( + &r, + &Vector3 { + x: 0.0, + y: 1.0, + z: 0.0 + } + )); + } + + #[test] + fn compose_translations_add() { + let a = Transform { + translation: Vector3 { + x: 1.0, + y: 0.0, + z: 0.0, + }, + rotation: Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }, + }; + let b = Transform { + translation: Vector3 { + x: 2.0, + y: 0.0, + z: 0.0, + }, + rotation: Quaternion { + x: 0.0, + y: 0.0, + z: 0.0, + w: 1.0, + }, + }; + let r = compose_transforms(&a, &b); + assert!(vec_approx_eq( + &r.translation, + &Vector3 { + x: 3.0, + y: 0.0, + z: 0.0 + } + )); + } +} diff --git a/crates/hiroz/src/pubsub.rs b/crates/hiroz/src/pubsub.rs index 5a51466d..3d6ea16a 100644 --- a/crates/hiroz/src/pubsub.rs +++ b/crates/hiroz/src/pubsub.rs @@ -5,6 +5,8 @@ use std::{marker::PhantomData, sync::Arc}; use tracing::{debug, trace, warn}; use zenoh::liveliness::LivelinessToken; use zenoh::{Result, Session, Wait, sample::Sample}; +#[allow(deprecated)] +use zenoh_ext::{PublicationCache, SessionExt}; use crate::Builder; use crate::attachment::{Attachment, GidArray}; @@ -113,6 +115,10 @@ pub struct ZPub { gid: GidArray, inner: AdvancedPublisher<'static>, _lv_token: LivelinessToken, + /// Caches samples for TransientLocal durability so late-joining subscribers + /// can retrieve previously published data via an initial get() query. + #[allow(deprecated)] + _pub_cache: Option, with_attachment: bool, clock: crate::time::ZClock, events_mgr: Arc>, @@ -320,6 +326,7 @@ where debug!("[PUB] Key expression: {}", key_expr); // Map QoS to Zenoh publisher settings + let cache_key_expr = key_expr.clone(); let mut pub_builder = self.session.declare_publisher(key_expr); // Map reliability: Reliable uses Block, BestEffort uses Drop @@ -344,6 +351,35 @@ where let inner = pub_builder.wait()?; debug!("[PUB] Publisher ready: topic={}", self.entity.topic); + // For TransientLocal publishers, declare a PublicationCache that answers + // get() queries from late-joining QueryingSubscribers. + let is_transient_local = + matches!(self.entity.qos.durability, QosDurability::TransientLocal); + #[allow(deprecated)] + let pub_cache: Option = if is_transient_local { + let history = match self.entity.qos.history { + QosHistory::KeepLast(n) => n, + QosHistory::KeepAll => 1000usize, + }; + match self + .session + .declare_publication_cache(&cache_key_expr) + .history(history) + .wait() + { + Ok(cache) => { + debug!("[PUB] PublicationCache declared (history={})", history); + Some(cache) + } + Err(e) => { + warn!("[PUB] Failed to declare PublicationCache: {}", e); + None + } + } + } else { + None + }; + let lv_ke = self .keyexpr_format .liveliness_key_expr(&self.entity, &self.session.zid())?; @@ -367,6 +403,7 @@ where sn: AtomicUsize::new(0), inner, _lv_token: lv_token, + _pub_cache: pub_cache, gid, clock: self.clock, events_mgr: Arc::new(Mutex::new(EventsManager::new(gid))),