Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions dimos/protocol/pubsub/impl/test_lcmpubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Generator
from collections.abc import Iterator
import time
from typing import Any

Expand All @@ -29,32 +29,28 @@
)
from dimos.utils.testing.collector import CallbackCollector

# Isolated multicast group so stale messages from other tests
# (which use the default 239.255.76.67:7667) don't leak in.
_ISOLATED_LCM_URL = "udpm://239.255.76.98:7698?ttl=0"


@pytest.fixture
def lcm_pub_sub_base() -> Generator[LCMPubSubBase, None, None]:
lcm = LCMPubSubBase(url=_ISOLATED_LCM_URL)
def lcm_pub_sub_base(lcm_url: str) -> Iterator[LCMPubSubBase]:
lcm = LCMPubSubBase(url=lcm_url)
lcm.start()
time.sleep(0.05) # let the handler thread enter the LCM loop
yield lcm
lcm.stop()


@pytest.fixture
def pickle_lcm() -> Generator[PickleLCM, None, None]:
lcm = PickleLCM(url=_ISOLATED_LCM_URL)
def pickle_lcm(lcm_url: str) -> Iterator[PickleLCM]:
lcm = PickleLCM(url=lcm_url)
lcm.start()
time.sleep(0.05) # let the handler thread enter the LCM loop
yield lcm
lcm.stop()


@pytest.fixture
def lcm() -> Generator[LCM, None, None]:
lcm = LCM(url=_ISOLATED_LCM_URL)
def lcm(lcm_url: str) -> Iterator[LCM]:
lcm = LCM(url=lcm_url)
lcm.start()
time.sleep(0.05) # let the handler thread enter the LCM loop
yield lcm
Expand Down
40 changes: 18 additions & 22 deletions dimos/protocol/pubsub/test_pattern_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

"""Grid tests for subscribe_all pattern subscriptions."""

from collections.abc import Callable, Generator
from collections.abc import Callable, Iterator
from contextlib import AbstractContextManager, contextmanager
from dataclasses import dataclass, field
import re
Expand Down Expand Up @@ -44,22 +44,18 @@ class Case(Generic[TopicT, MsgT]):
"""Test case for grid testing pubsub implementations."""

name: str
pubsub_context: Callable[[], AbstractContextManager[PubSubPair[TopicT, MsgT]]]
pubsub_context: Callable[[str], AbstractContextManager[PubSubPair[TopicT, MsgT]]]
topic_values: list[tuple[TopicT, MsgT]]
tags: set[str] = field(default_factory=set)
# Pattern tests: (pattern_topic, {indices of topic_values that should match})
glob_patterns: list[tuple[TopicT, set[int]]] = field(default_factory=list)
regex_patterns: list[tuple[TopicT, set[int]]] = field(default_factory=list)


# Use an isolated multicast group to avoid cross-test LCM contamination.
_ISOLATED_LCM_URL = "udpm://239.255.76.99:7699?ttl=0"


@contextmanager
def lcm_typed_context() -> Generator[tuple[LCM, LCM], None, None]:
pub = LCM(url=_ISOLATED_LCM_URL)
sub = LCM(url=_ISOLATED_LCM_URL)
def lcm_typed_context(url: str) -> Iterator[tuple[LCM, LCM]]:
pub = LCM(url=url)
sub = LCM(url=url)
pub.start()
sub.start()
try:
Expand All @@ -70,9 +66,9 @@ def lcm_typed_context() -> Generator[tuple[LCM, LCM], None, None]:


@contextmanager
def lcm_bytes_context() -> Generator[tuple[LCMPubSubBase, LCMPubSubBase], None, None]:
pub = LCMPubSubBase(url=_ISOLATED_LCM_URL)
sub = LCMPubSubBase(url=_ISOLATED_LCM_URL)
def lcm_bytes_context(url: str) -> Iterator[tuple[LCMPubSubBase, LCMPubSubBase]]:
pub = LCMPubSubBase(url=url)
sub = LCMPubSubBase(url=url)
pub.start()
sub.start()
try:
Expand Down Expand Up @@ -142,11 +138,11 @@ def _topic_matches_prefix(topic: Any, prefix: str = "/") -> bool:


@pytest.mark.parametrize("tc", all_cases, ids=lambda c: c.name)
def test_subscribe_all_receives_all_topics(tc: Case[Any, Any]) -> None:
def test_subscribe_all_receives_all_topics(tc: Case[Any, Any], lcm_url: str) -> None:
"""Test that subscribe_all receives messages from all topics."""
collector = CallbackCollector(len(tc.topic_values))

with tc.pubsub_context() as (pub, sub):
with tc.pubsub_context(lcm_url) as (pub, sub):
sub.subscribe_all(collector)
time.sleep(0.01) # Allow subscription to register

Expand All @@ -164,12 +160,12 @@ def test_subscribe_all_receives_all_topics(tc: Case[Any, Any]) -> None:


@pytest.mark.parametrize("tc", all_cases, ids=lambda c: c.name)
def test_subscribe_all_unsubscribe(tc: Case[Any, Any]) -> None:
def test_subscribe_all_unsubscribe(tc: Case[Any, Any], lcm_url: str) -> None:
"""Test that unsubscribe stops receiving messages."""
collector = CallbackCollector(1)
topic, value = tc.topic_values[0]

with tc.pubsub_context() as (pub, sub):
with tc.pubsub_context(lcm_url) as (pub, sub):
unsub = sub.subscribe_all(collector)
time.sleep(0.01) # Allow subscription to register

Expand All @@ -185,14 +181,14 @@ def test_subscribe_all_unsubscribe(tc: Case[Any, Any]) -> None:


@pytest.mark.parametrize("tc", all_cases, ids=lambda c: c.name)
def test_subscribe_all_with_regular_subscribe(tc: Case[Any, Any]) -> None:
def test_subscribe_all_with_regular_subscribe(tc: Case[Any, Any], lcm_url: str) -> None:
"""Test that subscribe_all coexists with regular subscriptions."""
all_collector = CallbackCollector(2)
specific_received: list[tuple[Any, Any]] = []
topic1, value1 = tc.topic_values[0]
topic2, value2 = tc.topic_values[1]

with tc.pubsub_context() as (pub, sub):
with tc.pubsub_context(lcm_url) as (pub, sub):
sub.subscribe_all(
lambda msg, topic: all_collector(msg, topic) if _topic_matches_prefix(topic) else None
)
Expand All @@ -212,12 +208,12 @@ def test_subscribe_all_with_regular_subscribe(tc: Case[Any, Any]) -> None:


@pytest.mark.parametrize("tc", glob_cases, ids=lambda c: c.name)
def test_subscribe_glob(tc: Case[Any, Any]) -> None:
def test_subscribe_glob(tc: Case[Any, Any], lcm_url: str) -> None:
"""Test that glob pattern subscriptions receive only matching topics."""
for pattern_topic, expected_indices in tc.glob_patterns:
collector = CallbackCollector(len(expected_indices))

with tc.pubsub_context() as (pub, sub):
with tc.pubsub_context(lcm_url) as (pub, sub):
sub.subscribe(pattern_topic, collector)
time.sleep(0.01) # Allow subscription to register

Expand All @@ -238,12 +234,12 @@ def test_subscribe_glob(tc: Case[Any, Any]) -> None:


@pytest.mark.parametrize("tc", regex_cases, ids=lambda c: c.name)
def test_subscribe_regex(tc: Case[Any, Any]) -> None:
def test_subscribe_regex(tc: Case[Any, Any], lcm_url: str) -> None:
"""Test that regex pattern subscriptions receive only matching topics."""
for pattern_topic, expected_indices in tc.regex_patterns:
collector = CallbackCollector(len(expected_indices))

with tc.pubsub_context() as (pub, sub):
with tc.pubsub_context(lcm_url) as (pub, sub):
sub.subscribe(pattern_topic, collector)
time.sleep(0.01) # Allow subscription to register

Expand Down
Loading