diff --git a/packages/python/slop-ai/src/slop_ai/discovery/service.py b/packages/python/slop-ai/src/slop_ai/discovery/service.py index ba4ea46..abde3e9 100644 --- a/packages/python/slop-ai/src/slop_ai/discovery/service.py +++ b/packages/python/slop-ai/src/slop_ai/discovery/service.py @@ -25,6 +25,59 @@ from .relay_transport import BridgeRelayTransport +def _normalize_match(value: str) -> str: + return "".join(ch.lower() for ch in value if ch.isalnum()) + + +def _resolve_named_match( + items: list[Any], + id_or_name: str, + *, + get_id: Callable[[Any], str], + get_name: Callable[[Any], str], +) -> Any | None: + raw = id_or_name.strip() + if not raw: + return None + + lower = raw.lower() + normalized = _normalize_match(raw) + + exact_id = [item for item in items if get_id(item) == raw] + if len(exact_id) == 1: + return exact_id[0] + + exact_text = [ + item + for item in items + if get_id(item).lower() == lower or get_name(item).lower() == lower + ] + if len(exact_text) == 1: + return exact_text[0] + + exact_normalized = [ + item + for item in items + if _normalize_match(get_id(item)) == normalized + or _normalize_match(get_name(item)) == normalized + ] + if len(exact_normalized) == 1: + return exact_normalized[0] + + partial = [ + item + for item in items + if lower in get_id(item).lower() + or lower in get_name(item).lower() + or (len(normalized) >= 2 and normalized in _normalize_match(get_id(item))) + or (len(normalized) >= 2 and normalized in _normalize_match(get_name(item))) + ] + if len(partial) == 1: + return partial[0] + + return None + + class DiscoveryService: """Discover, connect, and manage local and browser-backed SLOP providers.""" @@ -129,7 +182,7 @@ def get_provider(self, provider_id: str) -> ConnectedProvider | None: provider = self._providers.get(provider_id) if provider is None or provider.status != "connected": return None - self._last_accessed[provider.id] = asyncio.get_running_loop().time() + self._touch_provider(provider.id) return provider async def ensure_connected(self, id_or_name: str) -> ConnectedProvider | None: @@ -426,35 +479,32 @@ def _create_transport(self, descriptor: ProviderDescriptor) -> Any | None: return None def _find_connected_provider(self, id_or_name: str) -> ConnectedProvider | None: - provider = self._providers.get(id_or_name) + provider = _resolve_named_match( + list(self._providers.values()), + id_or_name, + get_id=lambda item: item.id, + get_name=lambda item: item.name, + ) if provider is not None and provider.status == "connected": - self._last_accessed[provider.id] = asyncio.get_running_loop().time() + self._touch_provider(provider.id) return provider - - needle = id_or_name.lower() - for provider in self._providers.values(): - if provider.status == "connected" and needle in provider.name.lower(): - self._last_accessed[provider.id] = asyncio.get_running_loop().time() - return provider return None def _find_any_provider(self, id_or_name: str) -> ConnectedProvider | None: - provider = self._providers.get(id_or_name) - if provider is not None: - return provider - - needle = id_or_name.lower() - for provider in self._providers.values(): - if needle in provider.name.lower(): - return provider - return None + return _resolve_named_match( + list(self._providers.values()), + id_or_name, + get_id=lambda item: item.id, + get_name=lambda item: item.name, + ) def _find_descriptor(self, id_or_name: str) -> ProviderDescriptor | None: - needle = id_or_name.lower() - for descriptor in self.get_discovered(): - if descriptor.id == id_or_name or needle in descriptor.name.lower(): - return descriptor - return None + return _resolve_named_match( + self.get_discovered(), + id_or_name, + get_id=lambda item: item.id, + get_name=lambda item: item.name, + ) def _forget_provider(self, provider_id: str) -> None: self._providers.pop(provider_id, None) @@ -464,6 +514,10 @@ def _forget_provider(self, provider_id: str) -> None: if reconnect_task is not None: reconnect_task.cancel() + def _touch_provider(self, provider_id: str) -> None: + with contextlib.suppress(RuntimeError): + self._last_accessed[provider_id] = asyncio.get_running_loop().time() + def _fire_state_change(self) -> None: for handler in self._state_change_handlers: handler() diff --git a/packages/python/slop-ai/tests/test_discovery.py b/packages/python/slop-ai/tests/test_discovery.py index 8dcbab5..60d6491 100644 --- a/packages/python/slop-ai/tests/test_discovery.py +++ b/packages/python/slop-ai/tests/test_discovery.py @@ -6,7 +6,7 @@ import json import socket from pathlib import Path -from typing import Any +from typing import Any, cast import pytest @@ -16,6 +16,11 @@ DiscoveryOptions, DiscoveryService, ) +from slop_ai.discovery.models import ( + ConnectedProvider, + ProviderDescriptor, + TransportDescriptor, +) def test_service_scans_and_prunes_descriptors(tmp_path: Path) -> None: @@ -113,6 +118,64 @@ async def _run() -> None: asyncio.run(_run()) +def test_service_matching_rejects_empty_and_ambiguous_app_names() -> None: + async def _run() -> None: + service = DiscoveryService(DiscoveryOptions(providers_dirs=[])) + service._local_descriptors = [ + ProviderDescriptor( + id="canvas-app", + name="Canvas App", + slop_version="0.1", + transport=TransportDescriptor(type="unix", path="/tmp/canvas.sock"), + capabilities=["state"], + ), + ProviderDescriptor( + id="canvas-pro", + name="Canvas Pro", + slop_version="0.1", + transport=TransportDescriptor(type="unix", path="/tmp/canvas-pro.sock"), + capabilities=["state"], + ), + ] + service._providers = { + "canvas-app": ConnectedProvider( + id="canvas-app", + name="Canvas App", + descriptor=service._local_descriptors[0], + consumer=cast(Any, _FakeConsumer()), + subscription_id="sub-1", + status="connected", + ) + } + + assert service._find_descriptor("") is None + assert service._find_descriptor(" ") is None + assert service._find_connected_provider("") is None + assert service._find_any_provider(" ") is None + assert service._find_descriptor("canvas") is None + assert service._find_descriptor("canvas-app") is not None + assert service._find_descriptor("Canvas App") is not None + assert service._find_connected_provider("Canvas App") is not None + + asyncio.run(_run()) + + +def test_service_matching_supports_normalized_unique_matches() -> None: + service = DiscoveryService(DiscoveryOptions(providers_dirs=[])) + descriptor = ProviderDescriptor( + id="whiteboard-canvas", + name="Whiteboard Canvas", + slop_version="0.1", + transport=TransportDescriptor(type="unix", path="/tmp/whiteboard.sock"), + capabilities=["state"], + ) + service._local_descriptors = [descriptor] + + assert service._find_descriptor("whiteboard canvas") == descriptor + assert service._find_descriptor("WhiteboardCanvas") == descriptor + assert service._find_descriptor("board") == descriptor + + async def _wait_until(predicate: Any, timeout: float = 1.0) -> None: deadline = asyncio.get_running_loop().time() + timeout while asyncio.get_running_loop().time() < deadline: @@ -165,3 +228,8 @@ def start(self) -> None: async def stop(self) -> None: return None + + +class _FakeConsumer: + def disconnect(self) -> None: + return None