diff --git a/core/libs/commonwealth/src/commonwealth/utils/tree.py b/core/libs/commonwealth/src/commonwealth/utils/tree.py new file mode 100644 index 0000000000..e5febaa9ce --- /dev/null +++ b/core/libs/commonwealth/src/commonwealth/utils/tree.py @@ -0,0 +1,72 @@ +from typing import Any, Callable, Dict, List + +import zenoh + + +class TreeNode: + def __init__(self, segment: str): + self.segment = segment + self.children: Dict[str, "TreeNode"] = {} + self.is_valid = False + self.methods: Dict[str, Callable[..., Any]] = {} + + def add_child(self, child: "TreeNode") -> "TreeNode": + if child.segment in self.children: + return self.children[child.segment] + self.children[child.segment] = child + return child + + def get_methods(self) -> Dict[str, Callable[..., Any]]: + return self.methods + + def add_node(self, segments: List[str], method: str, func: Callable[..., Any]) -> None: + if self.segment != segments[0]: + return + + if len(segments) == 1: + self.is_valid = True + self.methods[method] = func + return + + child = self.add_child(TreeNode(segments[1])) + child.add_node(segments[1:], method, func) + + def process_path(self, path: str, method: str, func: Callable[..., Any]) -> None: + segments = path.split("/") + self.add_node(segments, method, func) + + def get_match(self, path: str) -> tuple[str, "TreeNode"] | None: + segments = path.split("/") + + if self.segment != segments[0]: + return None + + node = self + matched_path = node.segment + + for segment in segments[1:]: + try: + node = node.children[segment] + matched_path += "/" + node.segment + except KeyError: + try: + node = node.children["*"] + matched_path += "/" + node.segment + except KeyError: + return None + + if node.is_valid: + return matched_path, node + return None + + def get_corresponding_method(self, sampleKind: zenoh.SampleKind) -> Callable[..., Any] | None: + if sampleKind == zenoh.SampleKind.DELETE: + keyword = "DELETE" + else: + methods = [method for method, _ in self.methods.items() if method != "DELETE"] + keyword = methods[0] + + try: + return self.methods[keyword] + except KeyError: + return None diff --git a/core/libs/commonwealth/src/commonwealth/utils/zenoh_helper.py b/core/libs/commonwealth/src/commonwealth/utils/zenoh_helper.py index 71914cf469..bd538912a7 100644 --- a/core/libs/commonwealth/src/commonwealth/utils/zenoh_helper.py +++ b/core/libs/commonwealth/src/commonwealth/utils/zenoh_helper.py @@ -1,17 +1,35 @@ +import ast import asyncio +import inspect import json import re from concurrent.futures import ThreadPoolExecutor +from functools import wraps from typing import Any, Callable import fastapi import zenoh +from commonwealth.utils.tree import TreeNode from fastapi.routing import APIRoute from loguru import logger +from starlette.responses import StreamingResponse from .Singleton import Singleton -PARAM_REGEX = r"{[a-zA-Z0-9_]+}" +PARAM_REGEX = r"{[a-zA-Z0-9_:]+}" +RESPONSE_PREFIX = "response/" + + +def _async_to_sync(async_func: Callable[..., Any]) -> Callable[..., Any]: + """ + Decorator to convert an async function to a sync function. + """ + + @wraps(async_func) + def wrapper() -> None: + asyncio.run(async_func()) + + return wrapper class ZenohSession(metaclass=Singleton): @@ -66,24 +84,23 @@ def zenoh_config(self, service_name: str) -> None: class ZenohRouter: prefix: str zenoh_session: ZenohSession + tree: TreeNode def __init__(self, service_name: str): self.prefix = service_name self.zenoh_session = ZenohSession(service_name) + self.tree = TreeNode(service_name) def add_queryable(self, path: str, func: Callable[..., Any]) -> None: - full_path = self.prefix - if path: - full_path += f"/{path}" - def wrapper(query: zenoh.Query) -> None: params = dict(query.parameters) # type: ignore + @_async_to_sync async def _handle_async() -> None: try: response = await func(**params) if response is not None: - query.reply(query.selector.key_expr, json.dumps(response, default=str)) + query.reply(query.selector.key_expr, handle_json(response)) except Exception as e: logger.exception(f"Error in zenoh query handler: {query.selector.key_expr}") error_response = { @@ -92,34 +109,182 @@ async def _handle_async() -> None: } query.reply(query.selector.key_expr, json.dumps(error_response)) - def run_async() -> None: - asyncio.run(_handle_async()) + self.zenoh_session.submit_to_executor(_handle_async) + + if self.zenoh_session.session: + self.zenoh_session.session.declare_queryable(path, wrapper) - self.zenoh_session.submit_to_executor(run_async) + def add_subscriber(self, path: str, func: Callable[..., Any]) -> None: + def wrapper(sample: zenoh.Sample) -> None: + if not self._should_process(sample, path, func): + return + + @_async_to_sync + async def _handle_async() -> None: + try: + parameters = get_parameters(sample, func) + result = await func(**parameters) + if result is not None: + await self.process_subscriber_response(result, sample) + except Exception as e: + logger.exception(f"Error in zenoh subscriber handler on {path}: {sample.kind=}, {str(e)}") + + self.zenoh_session.submit_to_executor(_handle_async) if self.zenoh_session.session: - self.zenoh_session.session.declare_queryable(full_path, wrapper) + self.zenoh_session.session.declare_subscriber(path, wrapper) + + async def process_subscriber_response(self, result: Any, sample: zenoh.Sample) -> None: + if self.zenoh_session.session is None: + return + + response_key = RESPONSE_PREFIX + str(sample.key_expr) + logger.info(f"The response for {sample.key_expr} will be published to {response_key}.") + + with self.zenoh_session.session.declare_publisher(response_key) as publisher: + if isinstance(result, StreamingResponse): + await self._handle_streaming_response(result, publisher) + else: + publisher.put(handle_json(result)) + + async def _handle_streaming_response(self, result: StreamingResponse, publisher: zenoh.Publisher) -> None: + async for chunk in result.body_iterator: + if isinstance(chunk, (dict, list)): + chunk_data = json.dumps(chunk, default=str) + elif isinstance(chunk, bytes): + chunk_data = chunk.decode("utf-8", errors="replace") + else: + chunk_data = str(chunk) + publisher.put(chunk_data) def add_routes_to_zenoh(self, app: fastapi.FastAPI) -> None: - queryables = [] + methods = self._get_methods(app) + for method, path, func in methods: + full_path = self.get_route_path(path) + self.tree.process_path(full_path, method, func) + + if method == "GET" and not is_streaming_response(func): + self.add_queryable(full_path, func) + else: + self.add_subscriber(full_path, func) + + def _get_methods(self, app: fastapi.FastAPI) -> list[tuple[str, str, Callable[..., Any]]]: + methods = [] for route in app.router.routes: - route_type = type(route) if ( isinstance(route, APIRoute) - and route_type.__name__ == "VersionedAPIRoute" - and "fastapi_versioning" in route_type.__module__ - and "GET" in route.methods + and type(route).__name__ == "VersionedAPIRoute" + and "fastapi_versioning" in type(route).__module__ ): - queryables.append((clean_path(route.path), route.endpoint)) + method_type = next(iter(route.methods), None) + if method_type and method_type in ("GET", "POST", "PUT", "DELETE"): + methods.append((method_type, clean_path(route.path), route.endpoint)) + return methods + + def get_route_path(self, path: str) -> str: + return self.prefix + "/" + path if path else self.prefix + + def _should_process(self, sample: zenoh.Sample, endpoint: str, func: Callable[..., Any]) -> bool: + matched = self.tree.get_match(str(sample.key_expr)) + + if matched is None: + return False + + matched_path, node = matched + if matched_path != endpoint: + return False + + found_method = node.get_corresponding_method(sample.kind) + if found_method is None: + return False + return found_method == func + - for path, func in queryables: - self.add_queryable(path, func) +def is_streaming_response(func: Callable[..., Any]) -> bool: + signature = inspect.signature(func) + return_annotation = signature.return_annotation + + if return_annotation is inspect.Signature.empty: + return False + + if inspect.isclass(return_annotation): + return issubclass(return_annotation, StreamingResponse) + + return False def clean_path(path: str) -> str: path = path.removeprefix("/").removesuffix("/") - zenoh_path = re.sub(PARAM_REGEX, "*", path) - zenoh_path = zenoh_path.replace("*/*", "**") return zenoh_path + + +def get_parameters(sample: zenoh.Sample, func: Callable[..., Any]) -> dict[str, Any]: + source = None + if sample.kind == zenoh.SampleKind.PUT: + source = sample.payload + else: + source = sample.attachment + + if source is None: + return {} + + parameters = parameters_process(source.to_string()) + return parameters_type_validation(parameters, func) + + +def parameters_process(parameters: str) -> dict[str, str]: + if not parameters: + return {} + + result = {} + for parameter in parameters.split(";"): + parameter = parameter.strip() + if not parameter: + continue + + if "=" not in parameter: + logger.warning(f"Skipping malformated parameter (no '=' found): {parameter}") + continue + + key, value = parameter.split("=", maxsplit=1) + key, value = key.strip(), value.strip() + + if key: + result[key] = value + return result + + +def parameters_type_validation(parameters: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]: + signature = inspect.signature(func) + typed_parameters = {} + + for key, value in parameters.items(): + if key not in signature.parameters: + continue + + annotation = signature.parameters[key].annotation + try: + if is_primitive(annotation): + typed_parameters[key] = annotation(value) + else: + parsed = ast.literal_eval(value) + typed_parameters[key] = annotation(**parsed) + except Exception as e: + logger.warning(f"Error converting parameter {key} to type {annotation}: {e}") + continue + + return typed_parameters + + +def is_primitive(value: Any) -> bool: + return value in (int, float, str, bool) + + +def handle_json(data: Any) -> str: + try: + return json.dumps(data, default=str) + except (TypeError, ValueError) as e: + logger.error(f"Error serializing data to JSON: {data}, {e}") + return '{"error": "Serialization failed"}'