diff --git a/dimos/protocol/pubsub/impl/test_lcmpubsub.py b/dimos/protocol/pubsub/impl/test_lcmpubsub.py index ba29c70958..c53bc32da2 100644 --- a/dimos/protocol/pubsub/impl/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/test_lcmpubsub.py @@ -13,7 +13,6 @@ # limitations under the License. from collections.abc import Generator -import time from typing import Any import pytest @@ -27,6 +26,7 @@ PickleLCM, Topic, ) +from dimos.utils.testing.collector import CallbackCollector @pytest.fixture @@ -74,25 +74,19 @@ def __eq__(self, other: object) -> bool: def test_LCMPubSubBase_pubsub(lcm_pub_sub_base: LCMPubSubBase) -> None: lcm = lcm_pub_sub_base - - received_messages: list[tuple[Any, Any]] = [] + collector = CallbackCollector(1) topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) test_message = MockLCMMessage("test_data") - def callback(msg: Any, topic: Any) -> None: - received_messages.append((msg, topic)) - - lcm.subscribe(topic, callback) + lcm.subscribe(topic, collector) lcm.publish(topic, test_message.lcm_encode()) - time.sleep(0.1) + collector.wait() - assert len(received_messages) == 1 + assert len(collector.results) == 1 - received_data = received_messages[0][0] - received_topic = received_messages[0][1] - - print(f"Received data: {received_data}, Topic: {received_topic}") + received_data = collector.results[0][0] + received_topic = collector.results[0][1] assert isinstance(received_data, bytes) assert received_data.decode() == "test_data" @@ -102,24 +96,19 @@ def callback(msg: Any, topic: Any) -> None: def test_lcm_autodecoder_pubsub(lcm: LCM) -> None: - received_messages: list[tuple[Any, Any]] = [] + collector = CallbackCollector(1) topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) test_message = MockLCMMessage("test_data") - def callback(msg: Any, topic: Any) -> None: - received_messages.append((msg, topic)) - - lcm.subscribe(topic, callback) + lcm.subscribe(topic, collector) lcm.publish(topic, test_message) - time.sleep(0.1) + collector.wait() - assert len(received_messages) == 1 + assert len(collector.results) == 1 - received_data = received_messages[0][0] - received_topic = received_messages[0][1] - - print(f"Received data: {received_data}, Topic: {received_topic}") + received_data = collector.results[0][0] + received_topic = collector.results[0][1] assert isinstance(received_data, MockLCMMessage) assert received_data == test_message @@ -138,24 +127,18 @@ def callback(msg: Any, topic: Any) -> None: # passes some geometry types through LCM @pytest.mark.parametrize("test_message", test_msgs) def test_lcm_geometry_msgs_pubsub(test_message: Any, lcm: LCM) -> None: - received_messages: list[tuple[Any, Any]] = [] + collector = CallbackCollector(1) topic = Topic(topic="/test_topic", lcm_type=test_message.__class__) - def callback(msg: Any, topic: Any) -> None: - received_messages.append((msg, topic)) - - lcm.subscribe(topic, callback) + lcm.subscribe(topic, collector) lcm.publish(topic, test_message) + collector.wait() - time.sleep(0.1) - - assert len(received_messages) == 1 + assert len(collector.results) == 1 - received_data = received_messages[0][0] - received_topic = received_messages[0][1] - - print(f"Received data: {received_data}, Topic: {received_topic}") + received_data = collector.results[0][0] + received_topic = collector.results[0][1] assert isinstance(received_data, test_message.__class__) assert received_data == test_message @@ -163,36 +146,26 @@ def callback(msg: Any, topic: Any) -> None: assert isinstance(received_topic, Topic) assert received_topic == topic - print(test_message, topic) - # passes some geometry types through pickle LCM @pytest.mark.parametrize("test_message", test_msgs) def test_lcm_geometry_msgs_autopickle_pubsub(test_message: Any, pickle_lcm: PickleLCM) -> None: lcm = pickle_lcm - received_messages: list[tuple[Any, Any]] = [] + collector = CallbackCollector(1) topic = Topic(topic="/test_topic") - def callback(msg: Any, topic: Any) -> None: - received_messages.append((msg, topic)) - - lcm.subscribe(topic, callback) + lcm.subscribe(topic, collector) lcm.publish(topic, test_message) + collector.wait() - time.sleep(0.1) + assert len(collector.results) == 1 - assert len(received_messages) == 1 - - received_data = received_messages[0][0] - received_topic = received_messages[0][1] - - print(f"Received data: {received_data}, Topic: {received_topic}") + received_data = collector.results[0][0] + received_topic = collector.results[0][1] assert isinstance(received_data, test_message.__class__) assert received_data == test_message assert isinstance(received_topic, Topic) assert received_topic == topic - - print(test_message, topic) diff --git a/dimos/protocol/pubsub/impl/test_rospubsub.py b/dimos/protocol/pubsub/impl/test_rospubsub.py index ef9df74227..6f29b3591b 100644 --- a/dimos/protocol/pubsub/impl/test_rospubsub.py +++ b/dimos/protocol/pubsub/impl/test_rospubsub.py @@ -13,7 +13,6 @@ # limitations under the License. from collections.abc import Generator -import threading from dimos_lcm.geometry_msgs import PointStamped import numpy as np @@ -28,6 +27,7 @@ # Add msg_name to LCM PointStamped for testing nested message conversion PointStamped.msg_name = "geometry_msgs.PointStamped" from dimos.utils.data import get_data +from dimos.utils.testing.collector import CallbackCollector from dimos.utils.testing.replay import TimedSensorReplay @@ -57,20 +57,14 @@ def test_basic_conversion(publisher, subscriber): Simple flat dimos.msgs type with no nesting (just x/y/z floats). """ topic = ROSTopic("/test_ros_topic", Vector3) + collector = CallbackCollector(1) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, Vector3(1.0, 2.0, 3.0)) - assert event.wait(timeout=2.0), "No message received" - assert len(received) == 1 - msg = received[0] + collector.wait() + assert len(collector.results) == 1 + msg = collector.results[0][0] assert msg.x == 1.0 assert msg.y == 2.0 assert msg.z == 3.0 @@ -95,21 +89,15 @@ def test_pointcloud2_pubsub(publisher, subscriber): assert len(original) > 0, "Loaded empty pointcloud" topic = ROSTopic("/test_pointcloud2", PointCloud2) + collector = CallbackCollector(1, timeout=5.0) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, original) - assert event.wait(timeout=5.0), "No PointCloud2 message received" - assert len(received) == 1 + collector.wait() + assert len(collector.results) == 1 - converted = received[0] + converted = collector.results[0][0] # Verify point cloud data is preserved original_points, _ = original.as_numpy() @@ -147,20 +135,14 @@ def test_pointcloud2_empty_pubsub(publisher, subscriber): ) topic = ROSTopic("/test_empty_pointcloud", PointCloud2) + collector = CallbackCollector(1) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, original) - assert event.wait(timeout=2.0), "No empty PointCloud2 message received" - assert len(received) == 1 - assert len(received[0]) == 0 + collector.wait() + assert len(collector.results) == 1 + assert len(collector.results[0][0]) == 0 @pytest.mark.skipif_no_ros @@ -178,21 +160,15 @@ def test_posestamped_pubsub(publisher, subscriber): ) topic = ROSTopic("/test_posestamped", PoseStamped) + collector = CallbackCollector(1) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, original) - assert event.wait(timeout=2.0), "No PoseStamped message received" - assert len(received) == 1 + collector.wait() + assert len(collector.results) == 1 - converted = received[0] + converted = collector.results[0][0] # Verify all fields preserved assert converted.frame_id == original.frame_id @@ -220,21 +196,15 @@ def test_pointstamped_pubsub(publisher, subscriber): original.point.z = 3.5 topic = ROSTopic("/test_pointstamped", PointStamped) + collector = CallbackCollector(1) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, original) - assert event.wait(timeout=2.0), "No PointStamped message received" - assert len(received) == 1 + collector.wait() + assert len(collector.results) == 1 - converted = received[0] + converted = collector.results[0][0] # Verify nested header fields are preserved assert converted.header.frame_id == original.header.frame_id @@ -260,21 +230,15 @@ def test_twist_pubsub(publisher, subscriber): ) topic = ROSTopic("/test_twist", Twist) + collector = CallbackCollector(1) - received = [] - event = threading.Event() - - def callback(msg, t): - received.append(msg) - event.set() - - subscriber.subscribe(topic, callback) + subscriber.subscribe(topic, collector) publisher.publish(topic, original) - assert event.wait(timeout=2.0), "No Twist message received" - assert len(received) == 1 + collector.wait() + assert len(collector.results) == 1 - converted = received[0] + converted = collector.results[0][0] # Verify linear velocity preserved assert converted.linear.x == original.linear.x diff --git a/dimos/protocol/pubsub/test_pattern_sub.py b/dimos/protocol/pubsub/test_pattern_sub.py index ac94ba1b3b..4b888f4bba 100644 --- a/dimos/protocol/pubsub/test_pattern_sub.py +++ b/dimos/protocol/pubsub/test_pattern_sub.py @@ -30,6 +30,7 @@ from dimos.protocol.pubsub.impl.lcmpubsub import LCM, LCMPubSubBase, Topic from dimos.protocol.pubsub.patterns import Glob from dimos.protocol.pubsub.spec import AllPubSub, PubSub +from dimos.utils.testing.collector import CallbackCollector TopicT = TypeVar("TopicT") MsgT = TypeVar("MsgT") @@ -139,22 +140,20 @@ def _topic_matches_prefix(topic: Any, prefix: str = "/") -> bool: @pytest.mark.parametrize("tc", all_cases, ids=lambda c: c.name) def test_subscribe_all_receives_all_topics(tc: Case[Any, Any]) -> None: """Test that subscribe_all receives messages from all topics.""" - received: list[tuple[Any, Any]] = [] + collector = CallbackCollector(len(tc.topic_values)) with tc.pubsub_context() as (pub, sub): - # Filter to only our test topics (LCM multicast can leak from parallel tests) - sub.subscribe_all(lambda msg, topic: received.append((msg, topic))) - time.sleep(0.01) # Allow subscription to be ready + sub.subscribe_all(collector) + time.sleep(0.01) # Allow subscription to register for topic, value in tc.topic_values: pub.publish(topic, value) - time.sleep(0.01) + collector.wait() - assert len(received) == len(tc.topic_values) + assert len(collector.results) == len(tc.topic_values) - # Verify all messages were received - received_msgs = [r[0] for r in received] + received_msgs = [r[0] for r in collector.results] expected_msgs = [v for _, v in tc.topic_values] for expected in expected_msgs: assert expected in received_msgs @@ -163,47 +162,45 @@ def test_subscribe_all_receives_all_topics(tc: Case[Any, Any]) -> None: @pytest.mark.parametrize("tc", all_cases, ids=lambda c: c.name) def test_subscribe_all_unsubscribe(tc: Case[Any, Any]) -> None: """Test that unsubscribe stops receiving messages.""" - received: list[tuple[Any, Any]] = [] + collector = CallbackCollector(1) topic, value = tc.topic_values[0] with tc.pubsub_context() as (pub, sub): - unsub = sub.subscribe_all(lambda msg, topic: received.append((msg, topic))) - time.sleep(0.01) # Allow subscription to be ready + unsub = sub.subscribe_all(collector) + time.sleep(0.01) # Allow subscription to register pub.publish(topic, value) - time.sleep(0.01) - assert len(received) == 1 + collector.wait() + assert len(collector.results) == 1 unsub() pub.publish(topic, value) - time.sleep(0.01) - assert len(received) == 1 # No new messages + time.sleep(0.1) # Wait to confirm no new messages arrive + assert len(collector.results) == 1 # No new messages @pytest.mark.parametrize("tc", all_cases, ids=lambda c: c.name) def test_subscribe_all_with_regular_subscribe(tc: Case[Any, Any]) -> None: """Test that subscribe_all coexists with regular subscriptions.""" - all_received: list[tuple[Any, Any]] = [] + all_collector = CallbackCollector(2) specific_received: list[tuple[Any, Any]] = [] topic1, value1 = tc.topic_values[0] topic2, value2 = tc.topic_values[1] with tc.pubsub_context() as (pub, sub): sub.subscribe_all( - lambda msg, topic: all_received.append((msg, topic)) - if _topic_matches_prefix(topic) - else None + lambda msg, topic: all_collector(msg, topic) if _topic_matches_prefix(topic) else None ) sub.subscribe(topic1, lambda msg, topic: specific_received.append((msg, topic))) - time.sleep(0.01) # Allow subscriptions to be ready + time.sleep(0.01) # Allow subscriptions to register pub.publish(topic1, value1) pub.publish(topic2, value2) - time.sleep(0.01) + all_collector.wait() # subscribe_all gets both - assert len(all_received) == 2 + assert len(all_collector.results) == 2 # specific subscription gets only topic1 assert len(specific_received) == 1 @@ -214,25 +211,24 @@ def test_subscribe_all_with_regular_subscribe(tc: Case[Any, Any]) -> None: def test_subscribe_glob(tc: Case[Any, Any]) -> None: """Test that glob pattern subscriptions receive only matching topics.""" for pattern_topic, expected_indices in tc.glob_patterns: - received: list[tuple[Any, Any]] = [] + collector = CallbackCollector(len(expected_indices)) with tc.pubsub_context() as (pub, sub): - sub.subscribe(pattern_topic, lambda msg, topic, r=received: r.append((msg, topic))) - time.sleep(0.01) # Allow subscription to be ready + sub.subscribe(pattern_topic, collector) + time.sleep(0.01) # Allow subscription to register for topic, value in tc.topic_values: pub.publish(topic, value) - time.sleep(0.01) + collector.wait() - assert len(received) == len(expected_indices), ( + assert len(collector.results) == len(expected_indices), ( f"Expected {len(expected_indices)} messages for pattern {pattern_topic}, " - f"got {len(received)}" + f"got {len(collector.results)}" ) - # Verify we received the expected messages expected_msgs = [tc.topic_values[i][1] for i in expected_indices] - received_msgs = [r[0] for r in received] + received_msgs = [r[0] for r in collector.results] for expected in expected_msgs: assert expected in received_msgs @@ -241,25 +237,23 @@ def test_subscribe_glob(tc: Case[Any, Any]) -> None: def test_subscribe_regex(tc: Case[Any, Any]) -> None: """Test that regex pattern subscriptions receive only matching topics.""" for pattern_topic, expected_indices in tc.regex_patterns: - received: list[tuple[Any, Any]] = [] + collector = CallbackCollector(len(expected_indices)) with tc.pubsub_context() as (pub, sub): - sub.subscribe(pattern_topic, lambda msg, topic, r=received: r.append((msg, topic))) - - time.sleep(0.01) + sub.subscribe(pattern_topic, collector) + time.sleep(0.01) # Allow subscription to register for topic, value in tc.topic_values: pub.publish(topic, value) - time.sleep(0.01) + collector.wait() - assert len(received) == len(expected_indices), ( + assert len(collector.results) == len(expected_indices), ( f"Expected {len(expected_indices)} messages for pattern {pattern_topic}, " - f"got {len(received)}" + f"got {len(collector.results)}" ) - # Verify we received the expected messages expected_msgs = [tc.topic_values[i][1] for i in expected_indices] - received_msgs = [r[0] for r in received] + received_msgs = [r[0] for r in collector.results] for expected in expected_msgs: assert expected in received_msgs diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index e36741bbfd..0e61132c1c 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -17,7 +17,6 @@ import asyncio from collections.abc import Callable, Generator from contextlib import contextmanager -import threading import time from typing import Any @@ -26,6 +25,7 @@ from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic from dimos.protocol.pubsub.impl.memory import Memory +from dimos.utils.testing.collector import CallbackCollector @contextmanager @@ -148,26 +148,14 @@ def shared_memory_cpu_context() -> Generator[PickleSharedMemory, None, None]: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) def test_store(pubsub_context: Callable[[], Any], topic: Any, values: list[Any]) -> None: with pubsub_context() as x: - # Create a list to capture received messages - received_messages: list[Any] = [] - msg_event = threading.Event() - - # Define callback function that stores received messages - def callback(message: Any, _: Any) -> None: - received_messages.append(message) - msg_event.set() - - # Subscribe to the topic with our callback - x.subscribe(topic, callback) + collector = CallbackCollector(1) - # Publish the first value to the topic + x.subscribe(topic, collector) x.publish(topic, values[0]) + collector.wait() - assert msg_event.wait(timeout=1.0), "Timed out waiting for message" - - # Verify the callback was called with the correct value - assert len(received_messages) == 1 - assert received_messages[0] == values[0] + assert len(collector.results) == 1 + assert collector.results[0][0] == values[0] @pytest.mark.parametrize("pubsub_context, topic, values", testdata) @@ -176,36 +164,21 @@ def test_multiple_subscribers( ) -> None: """Test that multiple subscribers receive the same message.""" with pubsub_context() as x: - # Create lists to capture received messages for each subscriber - received_messages_1: list[Any] = [] - received_messages_2: list[Any] = [] - event_1 = threading.Event() - event_2 = threading.Event() - - # Define callback functions - def callback_1(message: Any, topic: Any) -> None: - received_messages_1.append(message) - event_1.set() + collector_1 = CallbackCollector(1) + collector_2 = CallbackCollector(1) - def callback_2(message: Any, topic: Any) -> None: - received_messages_2.append(message) - event_2.set() + x.subscribe(topic, collector_1) + x.subscribe(topic, collector_2) - # Subscribe both callbacks to the same topic - x.subscribe(topic, callback_1) - x.subscribe(topic, callback_2) - - # Publish the first value x.publish(topic, values[0]) - assert event_1.wait(timeout=1.0), "Timed out waiting for subscriber 1" - assert event_2.wait(timeout=1.0), "Timed out waiting for subscriber 2" + collector_1.wait() + collector_2.wait() - # Verify both callbacks received the message - assert len(received_messages_1) == 1 - assert received_messages_1[0] == values[0] - assert len(received_messages_2) == 1 - assert received_messages_2[0] == values[0] + assert len(collector_1.results) == 1 + assert collector_1.results[0][0] == values[0] + assert len(collector_2.results) == 1 + assert collector_2.results[0][0] == values[0] @pytest.mark.parametrize("pubsub_context, topic, values", testdata) @@ -241,28 +214,17 @@ def test_multiple_messages( ) -> None: """Test that subscribers receive multiple messages in order.""" with pubsub_context() as x: - # Create a list to capture received messages - received_messages: list[Any] = [] - all_received = threading.Event() - - # Publish the rest of the values (after the first one used in basic tests) messages_to_send = values[1:] if len(values) > 1 else values + collector = CallbackCollector(len(messages_to_send)) - # Define callback function - def callback(message: Any, topic: Any) -> None: - received_messages.append(message) - if len(received_messages) >= len(messages_to_send): - all_received.set() - - # Subscribe to the topic - x.subscribe(topic, callback) + x.subscribe(topic, collector) for msg in messages_to_send: x.publish(topic, msg) - assert all_received.wait(timeout=1.0), "Timed out waiting for all messages" + collector.wait() - # Verify all messages were received in order + received_messages = [r[0] for r in collector.results] assert len(received_messages) == len(messages_to_send) assert received_messages == messages_to_send diff --git a/dimos/utils/testing/collector.py b/dimos/utils/testing/collector.py new file mode 100644 index 0000000000..bcc3150e73 --- /dev/null +++ b/dimos/utils/testing/collector.py @@ -0,0 +1,50 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Callback collector with Event-based synchronization for async tests.""" + +import threading +from typing import Any + + +class CallbackCollector: + """Callable that collects ``(msg, topic)`` pairs and signals when *n* arrive. + + Designed as a drop-in subscription callback for pubsub tests:: + + collector = CallbackCollector(3) + sub.subscribe(topic, collector) + # ... publish 3 messages ... + collector.wait() + assert len(collector.results) == 3 + """ + + def __init__(self, n: int, timeout: float = 2.0) -> None: + self.results: list[tuple[Any, Any]] = [] + self._done = threading.Event() + self._n = n + self.timeout = timeout + + def __call__(self, msg: Any, topic: Any) -> None: + self.results.append((msg, topic)) + if len(self.results) >= self._n: + self._done.set() + + def wait(self) -> None: + """Block until *n* items have been collected, or *timeout* expires.""" + if not self._done.wait(self.timeout): + raise AssertionError( + f"Timed out after {self.timeout}s waiting for {self._n} messages " + f"(got {len(self.results)})" + )