From 88afc21c4808b8c856491c908d910159cf108d6b Mon Sep 17 00:00:00 2001 From: Diego Carlino Date: Tue, 7 Apr 2026 18:39:24 +0200 Subject: [PATCH] fix(python): harden discovery provider matching This prevents empty or ambiguous app names from silently matching the wrong provider and keeps access tracking safe when discovery is called outside an active event loop. It adds targeted discovery tests for empty, ambiguous, and normalized name resolution. --- .../slop-ai/src/slop_ai/discovery/service.py | 100 ++++++++++++++---- .../python/slop-ai/tests/test_discovery.py | 70 +++++++++++- 2 files changed, 146 insertions(+), 24 deletions(-) diff --git a/packages/python/slop-ai/src/slop_ai/discovery/service.py b/packages/python/slop-ai/src/slop_ai/discovery/service.py index ba4ea46..abde3e9 100644 --- a/packages/python/slop-ai/src/slop_ai/discovery/service.py +++ b/packages/python/slop-ai/src/slop_ai/discovery/service.py @@ -25,6 +25,59 @@ from .relay_transport import BridgeRelayTransport +def _normalize_match(value: str) -> str: + return "".join(ch.lower() for ch in value if ch.isalnum()) + + +def _resolve_named_match( + items: list[Any], + id_or_name: str, + *, + get_id: Callable[[Any], str], + get_name: Callable[[Any], str], +) -> Any | None: + raw = id_or_name.strip() + if not raw: + return None + + lower = raw.lower() + normalized = _normalize_match(raw) + + exact_id = [item for item in items if get_id(item) == raw] + if len(exact_id) == 1: + return exact_id[0] + + exact_text = [ + item + for item in items + if get_id(item).lower() == lower or get_name(item).lower() == lower + ] + if len(exact_text) == 1: + return exact_text[0] + + exact_normalized = [ + item + for item in items + if _normalize_match(get_id(item)) == normalized + or _normalize_match(get_name(item)) == normalized + ] + if len(exact_normalized) == 1: + return exact_normalized[0] + + partial = [ + item + for item in items + if lower in get_id(item).lower() + or lower in get_name(item).lower() + or (len(normalized) >= 2 and normalized in _normalize_match(get_id(item))) + or (len(normalized) >= 2 and normalized in _normalize_match(get_name(item))) + ] + if len(partial) == 1: + return partial[0] + + return None + + class DiscoveryService: """Discover, connect, and manage local and browser-backed SLOP providers.""" @@ -129,7 +182,7 @@ def get_provider(self, provider_id: str) -> ConnectedProvider | None: provider = self._providers.get(provider_id) if provider is None or provider.status != "connected": return None - self._last_accessed[provider.id] = asyncio.get_running_loop().time() + self._touch_provider(provider.id) return provider async def ensure_connected(self, id_or_name: str) -> ConnectedProvider | None: @@ -426,35 +479,32 @@ def _create_transport(self, descriptor: ProviderDescriptor) -> Any | None: return None def _find_connected_provider(self, id_or_name: str) -> ConnectedProvider | None: - provider = self._providers.get(id_or_name) + provider = _resolve_named_match( + list(self._providers.values()), + id_or_name, + get_id=lambda item: item.id, + get_name=lambda item: item.name, + ) if provider is not None and provider.status == "connected": - self._last_accessed[provider.id] = asyncio.get_running_loop().time() + self._touch_provider(provider.id) return provider - - needle = id_or_name.lower() - for provider in self._providers.values(): - if provider.status == "connected" and needle in provider.name.lower(): - self._last_accessed[provider.id] = asyncio.get_running_loop().time() - return provider return None def _find_any_provider(self, id_or_name: str) -> ConnectedProvider | None: - provider = self._providers.get(id_or_name) - if provider is not None: - return provider - - needle = id_or_name.lower() - for provider in self._providers.values(): - if needle in provider.name.lower(): - return provider - return None + return _resolve_named_match( + list(self._providers.values()), + id_or_name, + get_id=lambda item: item.id, + get_name=lambda item: item.name, + ) def _find_descriptor(self, id_or_name: str) -> ProviderDescriptor | None: - needle = id_or_name.lower() - for descriptor in self.get_discovered(): - if descriptor.id == id_or_name or needle in descriptor.name.lower(): - return descriptor - return None + return _resolve_named_match( + self.get_discovered(), + id_or_name, + get_id=lambda item: item.id, + get_name=lambda item: item.name, + ) def _forget_provider(self, provider_id: str) -> None: self._providers.pop(provider_id, None) @@ -464,6 +514,10 @@ def _forget_provider(self, provider_id: str) -> None: if reconnect_task is not None: reconnect_task.cancel() + def _touch_provider(self, provider_id: str) -> None: + with contextlib.suppress(RuntimeError): + self._last_accessed[provider_id] = asyncio.get_running_loop().time() + def _fire_state_change(self) -> None: for handler in self._state_change_handlers: handler() diff --git a/packages/python/slop-ai/tests/test_discovery.py b/packages/python/slop-ai/tests/test_discovery.py index 8dcbab5..60d6491 100644 --- a/packages/python/slop-ai/tests/test_discovery.py +++ b/packages/python/slop-ai/tests/test_discovery.py @@ -6,7 +6,7 @@ import json import socket from pathlib import Path -from typing import Any +from typing import Any, cast import pytest @@ -16,6 +16,11 @@ DiscoveryOptions, DiscoveryService, ) +from slop_ai.discovery.models import ( + ConnectedProvider, + ProviderDescriptor, + TransportDescriptor, +) def test_service_scans_and_prunes_descriptors(tmp_path: Path) -> None: @@ -113,6 +118,64 @@ async def _run() -> None: asyncio.run(_run()) +def test_service_matching_rejects_empty_and_ambiguous_app_names() -> None: + async def _run() -> None: + service = DiscoveryService(DiscoveryOptions(providers_dirs=[])) + service._local_descriptors = [ + ProviderDescriptor( + id="canvas-app", + name="Canvas App", + slop_version="0.1", + transport=TransportDescriptor(type="unix", path="/tmp/canvas.sock"), + capabilities=["state"], + ), + ProviderDescriptor( + id="canvas-pro", + name="Canvas Pro", + slop_version="0.1", + transport=TransportDescriptor(type="unix", path="/tmp/canvas-pro.sock"), + capabilities=["state"], + ), + ] + service._providers = { + "canvas-app": ConnectedProvider( + id="canvas-app", + name="Canvas App", + descriptor=service._local_descriptors[0], + consumer=cast(Any, _FakeConsumer()), + subscription_id="sub-1", + status="connected", + ) + } + + assert service._find_descriptor("") is None + assert service._find_descriptor(" ") is None + assert service._find_connected_provider("") is None + assert service._find_any_provider(" ") is None + assert service._find_descriptor("canvas") is None + assert service._find_descriptor("canvas-app") is not None + assert service._find_descriptor("Canvas App") is not None + assert service._find_connected_provider("Canvas App") is not None + + asyncio.run(_run()) + + +def test_service_matching_supports_normalized_unique_matches() -> None: + service = DiscoveryService(DiscoveryOptions(providers_dirs=[])) + descriptor = ProviderDescriptor( + id="whiteboard-canvas", + name="Whiteboard Canvas", + slop_version="0.1", + transport=TransportDescriptor(type="unix", path="/tmp/whiteboard.sock"), + capabilities=["state"], + ) + service._local_descriptors = [descriptor] + + assert service._find_descriptor("whiteboard canvas") == descriptor + assert service._find_descriptor("WhiteboardCanvas") == descriptor + assert service._find_descriptor("board") == descriptor + + async def _wait_until(predicate: Any, timeout: float = 1.0) -> None: deadline = asyncio.get_running_loop().time() + timeout while asyncio.get_running_loop().time() < deadline: @@ -165,3 +228,8 @@ def start(self) -> None: async def stop(self) -> None: return None + + +class _FakeConsumer: + def disconnect(self) -> None: + return None