Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 77 additions & 23 deletions packages/python/slop-ai/src/slop_ai/discovery/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,59 @@
from .relay_transport import BridgeRelayTransport


def _normalize_match(value: str) -> str:
return "".join(ch.lower() for ch in value if ch.isalnum())


def _resolve_named_match(
items: list[Any],
id_or_name: str,
*,
get_id: Callable[[Any], str],
get_name: Callable[[Any], str],
) -> Any | None:
raw = id_or_name.strip()
if not raw:
return None

lower = raw.lower()
normalized = _normalize_match(raw)

exact_id = [item for item in items if get_id(item) == raw]
if len(exact_id) == 1:
return exact_id[0]

exact_text = [
item
for item in items
if get_id(item).lower() == lower or get_name(item).lower() == lower
]
if len(exact_text) == 1:
return exact_text[0]

exact_normalized = [
item
for item in items
if _normalize_match(get_id(item)) == normalized
or _normalize_match(get_name(item)) == normalized
]
if len(exact_normalized) == 1:
return exact_normalized[0]

partial = [
item
for item in items
if lower in get_id(item).lower()
or lower in get_name(item).lower()
or (len(normalized) >= 2 and normalized in _normalize_match(get_id(item)))
or (len(normalized) >= 2 and normalized in _normalize_match(get_name(item)))
]
if len(partial) == 1:
return partial[0]

return None


class DiscoveryService:
"""Discover, connect, and manage local and browser-backed SLOP providers."""

Expand Down Expand Up @@ -129,7 +182,7 @@ def get_provider(self, provider_id: str) -> ConnectedProvider | None:
provider = self._providers.get(provider_id)
if provider is None or provider.status != "connected":
return None
self._last_accessed[provider.id] = asyncio.get_running_loop().time()
self._touch_provider(provider.id)
return provider

async def ensure_connected(self, id_or_name: str) -> ConnectedProvider | None:
Expand Down Expand Up @@ -426,35 +479,32 @@ def _create_transport(self, descriptor: ProviderDescriptor) -> Any | None:
return None

def _find_connected_provider(self, id_or_name: str) -> ConnectedProvider | None:
provider = self._providers.get(id_or_name)
provider = _resolve_named_match(
list(self._providers.values()),
id_or_name,
get_id=lambda item: item.id,
get_name=lambda item: item.name,
)
if provider is not None and provider.status == "connected":
self._last_accessed[provider.id] = asyncio.get_running_loop().time()
self._touch_provider(provider.id)
return provider

needle = id_or_name.lower()
for provider in self._providers.values():
if provider.status == "connected" and needle in provider.name.lower():
self._last_accessed[provider.id] = asyncio.get_running_loop().time()
return provider
return None

def _find_any_provider(self, id_or_name: str) -> ConnectedProvider | None:
provider = self._providers.get(id_or_name)
if provider is not None:
return provider

needle = id_or_name.lower()
for provider in self._providers.values():
if needle in provider.name.lower():
return provider
return None
return _resolve_named_match(
list(self._providers.values()),
id_or_name,
get_id=lambda item: item.id,
get_name=lambda item: item.name,
)

def _find_descriptor(self, id_or_name: str) -> ProviderDescriptor | None:
needle = id_or_name.lower()
for descriptor in self.get_discovered():
if descriptor.id == id_or_name or needle in descriptor.name.lower():
return descriptor
return None
return _resolve_named_match(
self.get_discovered(),
id_or_name,
get_id=lambda item: item.id,
get_name=lambda item: item.name,
)

def _forget_provider(self, provider_id: str) -> None:
self._providers.pop(provider_id, None)
Expand All @@ -464,6 +514,10 @@ def _forget_provider(self, provider_id: str) -> None:
if reconnect_task is not None:
reconnect_task.cancel()

def _touch_provider(self, provider_id: str) -> None:
with contextlib.suppress(RuntimeError):
self._last_accessed[provider_id] = asyncio.get_running_loop().time()

def _fire_state_change(self) -> None:
for handler in self._state_change_handlers:
handler()
Expand Down
70 changes: 69 additions & 1 deletion packages/python/slop-ai/tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import socket
from pathlib import Path
from typing import Any
from typing import Any, cast

import pytest

Expand All @@ -16,6 +16,11 @@
DiscoveryOptions,
DiscoveryService,
)
from slop_ai.discovery.models import (
ConnectedProvider,
ProviderDescriptor,
TransportDescriptor,
)


def test_service_scans_and_prunes_descriptors(tmp_path: Path) -> None:
Expand Down Expand Up @@ -113,6 +118,64 @@ async def _run() -> None:
asyncio.run(_run())


def test_service_matching_rejects_empty_and_ambiguous_app_names() -> None:
async def _run() -> None:
service = DiscoveryService(DiscoveryOptions(providers_dirs=[]))
service._local_descriptors = [
ProviderDescriptor(
id="canvas-app",
name="Canvas App",
slop_version="0.1",
transport=TransportDescriptor(type="unix", path="/tmp/canvas.sock"),
capabilities=["state"],
),
ProviderDescriptor(
id="canvas-pro",
name="Canvas Pro",
slop_version="0.1",
transport=TransportDescriptor(type="unix", path="/tmp/canvas-pro.sock"),
capabilities=["state"],
),
]
service._providers = {
"canvas-app": ConnectedProvider(
id="canvas-app",
name="Canvas App",
descriptor=service._local_descriptors[0],
consumer=cast(Any, _FakeConsumer()),
subscription_id="sub-1",
status="connected",
)
}

assert service._find_descriptor("") is None
assert service._find_descriptor(" ") is None
assert service._find_connected_provider("") is None
assert service._find_any_provider(" ") is None
assert service._find_descriptor("canvas") is None
assert service._find_descriptor("canvas-app") is not None
assert service._find_descriptor("Canvas App") is not None
assert service._find_connected_provider("Canvas App") is not None

asyncio.run(_run())


def test_service_matching_supports_normalized_unique_matches() -> None:
service = DiscoveryService(DiscoveryOptions(providers_dirs=[]))
descriptor = ProviderDescriptor(
id="whiteboard-canvas",
name="Whiteboard Canvas",
slop_version="0.1",
transport=TransportDescriptor(type="unix", path="/tmp/whiteboard.sock"),
capabilities=["state"],
)
service._local_descriptors = [descriptor]

assert service._find_descriptor("whiteboard canvas") == descriptor
assert service._find_descriptor("WhiteboardCanvas") == descriptor
assert service._find_descriptor("board") == descriptor


async def _wait_until(predicate: Any, timeout: float = 1.0) -> None:
deadline = asyncio.get_running_loop().time() + timeout
while asyncio.get_running_loop().time() < deadline:
Expand Down Expand Up @@ -165,3 +228,8 @@ def start(self) -> None:

async def stop(self) -> None:
return None


class _FakeConsumer:
def disconnect(self) -> None:
return None
Loading