Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions core/libs/commonwealth/src/commonwealth/utils/tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Any, Callable, Dict, List

import zenoh


class TreeNode:
def __init__(self, segment: str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a description of how this class operates ? It's pretty complex just looking at the code.

self.segment = segment
self.children: Dict[str, "TreeNode"] = {}
self.is_valid = False
self.methods: Dict[str, Callable[..., Any]] = {}

def add_child(self, child: "TreeNode") -> "TreeNode":
if child.segment in self.children:
return self.children[child.segment]
self.children[child.segment] = child
return child
Comment on lines +14 to +17
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if child.segment in self.children:
return self.children[child.segment]
self.children[child.segment] = child
return child
if child.segment not in self.children:
self.children[child.segment] = child
return self.children[child.segment]


def get_methods(self) -> Dict[str, Callable[..., Any]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend to use _ prefix for private properties and avoid using get prefix in functions.

https://github.com/bluerobotics/software-guidelines/blob/master/guidelines/style.md

return self.methods

def add_node(self, segments: List[str], method: str, func: Callable[..., Any]) -> None:
if self.segment != segments[0]:
return

if len(segments) == 1:
self.is_valid = True
self.methods[method] = func
return

child = self.add_child(TreeNode(segments[1]))
child.add_node(segments[1:], method, func)

def process_path(self, path: str, method: str, func: Callable[..., Any]) -> None:
segments = path.split("/")
self.add_node(segments, method, func)

def get_match(self, path: str) -> tuple[str, "TreeNode"] | None:
segments = path.split("/")

if self.segment != segments[0]:
return None

node = self
matched_path = node.segment

for segment in segments[1:]:
try:
node = node.children[segment]
matched_path += "/" + node.segment
except KeyError:
try:
node = node.children["*"]
matched_path += "/" + node.segment
except KeyError:
return None

if node.is_valid:
return matched_path, node
return None

def get_corresponding_method(self, sampleKind: zenoh.SampleKind) -> Callable[..., Any] | None:
if sampleKind == zenoh.SampleKind.DELETE:
keyword = "DELETE"
else:
methods = [method for method, _ in self.methods.items() if method != "DELETE"]
keyword = methods[0]
Comment on lines +62 to +67
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Guard against empty or inconsistent method mappings when resolving the handler, to avoid IndexError and surprising routing.

In get_corresponding_method, when sampleKind != zenoh.SampleKind.DELETE you build methods = [...] and then access methods[0]. If self.methods is empty or only contains `

Comment on lines +63 to +67
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Clarify and tighten the mapping between SampleKind and HTTP methods to avoid arbitrary handler selection.

Here you special-case DELETE but otherwise just pick the first non-DELETE method from self.methods. If a node supports both GET and POST, the selected handler depends on dict insertion order, and different SampleKind values (e.g. PUT vs POST) end up mapped to the same handler. Consider defining an explicit SampleKind→HTTP method mapping, or at least a deterministic preference order (e.g. prefer PUT, then POST, etc.) to avoid dispatching to the wrong endpoint.

Suggested implementation:

        if node.is_valid:
            return matched_path, node
        return None

    def get_corresponding_method(self, sampleKind: zenoh.SampleKind) -> Callable[..., Any] | None:
        """
        Map a zenoh.SampleKind to an HTTP method in a deterministic way.

        1. Try an explicit SampleKind -> HTTP verb mapping.
        2. If the mapped verb is not supported by this node, fall back to a
           deterministic preference order among supported non-DELETE methods.
        3. As a final fallback, choose the first non-DELETE method in sorted
           order to keep behavior deterministic.
        """
        # Explicit mapping from SampleKind to preferred HTTP method
        samplekind_to_method: dict[zenoh.SampleKind, str] = {
            zenoh.SampleKind.DELETE: "DELETE",
            getattr(zenoh.SampleKind, "PUT", None): "PUT",
            getattr(zenoh.SampleKind, "PATCH", None): "PATCH",
        }

        # Remove any None keys in case PUT/PATCH do not exist in this zenoh version
        samplekind_to_method = {
            k: v for k, v in samplekind_to_method.items() if k is not None
        }

        keyword: str | None = samplekind_to_method.get(sampleKind)

        # If there is no direct mapping or the mapped method is not supported,
        # fall back to a deterministic preference order (excluding DELETE).
        if keyword is None or keyword not in self.methods:
            preference_order: tuple[str, ...] = ("PUT", "POST", "PATCH", "GET")
            keyword = None
            for method in preference_order:
                if method in self.methods and method != "DELETE":
                    keyword = method
                    break

            # As a last resort, pick the first non-DELETE method in sorted order
            if keyword is None:
                non_delete_methods = sorted(
                    m for m in self.methods.keys() if m != "DELETE"
                )
                if non_delete_methods:
                    keyword = non_delete_methods[0]

        if keyword is None:
            return None

        return self.methods.get(keyword)

If Callable and Any are not already imported in this file, add:

from typing import Any, Callable

near the other imports. Also ensure that zenoh.SampleKind.PUT and zenoh.SampleKind.PATCH exist in the version of zenoh you are using; the getattr usage in the mapping is there to avoid attribute errors on older versions.


try:
return self.methods[keyword]
except KeyError:
return None
205 changes: 185 additions & 20 deletions core/libs/commonwealth/src/commonwealth/utils/zenoh_helper.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,35 @@
import ast
import asyncio
import inspect
import json
import re
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any, Callable

import fastapi
import zenoh
from commonwealth.utils.tree import TreeNode
from fastapi.routing import APIRoute
from loguru import logger
from starlette.responses import StreamingResponse

from .Singleton import Singleton

PARAM_REGEX = r"{[a-zA-Z0-9_]+}"
PARAM_REGEX = r"{[a-zA-Z0-9_:]+}"
RESPONSE_PREFIX = "response/"


def _async_to_sync(async_func: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorator to convert an async function to a sync function.
"""

@wraps(async_func)
def wrapper() -> None:
asyncio.run(async_func())

return wrapper


class ZenohSession(metaclass=Singleton):
Expand Down Expand Up @@ -66,24 +84,23 @@ def zenoh_config(self, service_name: str) -> None:
class ZenohRouter:
prefix: str
zenoh_session: ZenohSession
tree: TreeNode

def __init__(self, service_name: str):
self.prefix = service_name
self.zenoh_session = ZenohSession(service_name)
self.tree = TreeNode(service_name)

def add_queryable(self, path: str, func: Callable[..., Any]) -> None:
full_path = self.prefix
if path:
full_path += f"/{path}"

def wrapper(query: zenoh.Query) -> None:
params = dict(query.parameters) # type: ignore

@_async_to_sync
async def _handle_async() -> None:
try:
response = await func(**params)
if response is not None:
query.reply(query.selector.key_expr, json.dumps(response, default=str))
query.reply(query.selector.key_expr, handle_json(response))
except Exception as e:
logger.exception(f"Error in zenoh query handler: {query.selector.key_expr}")
error_response = {
Expand All @@ -92,34 +109,182 @@ async def _handle_async() -> None:
}
query.reply(query.selector.key_expr, json.dumps(error_response))

def run_async() -> None:
asyncio.run(_handle_async())
self.zenoh_session.submit_to_executor(_handle_async)

if self.zenoh_session.session:
self.zenoh_session.session.declare_queryable(path, wrapper)

self.zenoh_session.submit_to_executor(run_async)
def add_subscriber(self, path: str, func: Callable[..., Any]) -> None:
def wrapper(sample: zenoh.Sample) -> None:
if not self._should_process(sample, path, func):
return

@_async_to_sync
async def _handle_async() -> None:
try:
parameters = get_parameters(sample, func)
result = await func(**parameters)
if result is not None:
await self.process_subscriber_response(result, sample)
except Exception as e:
logger.exception(f"Error in zenoh subscriber handler on {path}: {sample.kind=}, {str(e)}")

self.zenoh_session.submit_to_executor(_handle_async)

if self.zenoh_session.session:
self.zenoh_session.session.declare_queryable(full_path, wrapper)
self.zenoh_session.session.declare_subscriber(path, wrapper)

async def process_subscriber_response(self, result: Any, sample: zenoh.Sample) -> None:
if self.zenoh_session.session is None:
return

response_key = RESPONSE_PREFIX + str(sample.key_expr)
logger.info(f"The response for {sample.key_expr} will be published to {response_key}.")

with self.zenoh_session.session.declare_publisher(response_key) as publisher:
if isinstance(result, StreamingResponse):
await self._handle_streaming_response(result, publisher)
else:
publisher.put(handle_json(result))

async def _handle_streaming_response(self, result: StreamingResponse, publisher: zenoh.Publisher) -> None:
async for chunk in result.body_iterator:
if isinstance(chunk, (dict, list)):
chunk_data = json.dumps(chunk, default=str)
elif isinstance(chunk, bytes):
chunk_data = chunk.decode("utf-8", errors="replace")
else:
chunk_data = str(chunk)
publisher.put(chunk_data)

def add_routes_to_zenoh(self, app: fastapi.FastAPI) -> None:
queryables = []
methods = self._get_methods(app)
for method, path, func in methods:
full_path = self.get_route_path(path)
self.tree.process_path(full_path, method, func)

if method == "GET" and not is_streaming_response(func):
self.add_queryable(full_path, func)
else:
self.add_subscriber(full_path, func)

def _get_methods(self, app: fastapi.FastAPI) -> list[tuple[str, str, Callable[..., Any]]]:
methods = []
for route in app.router.routes:
route_type = type(route)
if (
isinstance(route, APIRoute)
and route_type.__name__ == "VersionedAPIRoute"
and "fastapi_versioning" in route_type.__module__
and "GET" in route.methods
and type(route).__name__ == "VersionedAPIRoute"
and "fastapi_versioning" in type(route).__module__
):
queryables.append((clean_path(route.path), route.endpoint))
method_type = next(iter(route.methods), None)
if method_type and method_type in ("GET", "POST", "PUT", "DELETE"):
methods.append((method_type, clean_path(route.path), route.endpoint))
return methods

def get_route_path(self, path: str) -> str:
return self.prefix + "/" + path if path else self.prefix

def _should_process(self, sample: zenoh.Sample, endpoint: str, func: Callable[..., Any]) -> bool:
matched = self.tree.get_match(str(sample.key_expr))

if matched is None:
return False

matched_path, node = matched
if matched_path != endpoint:
return False

found_method = node.get_corresponding_method(sample.kind)
if found_method is None:
return False
return found_method == func


for path, func in queryables:
self.add_queryable(path, func)
def is_streaming_response(func: Callable[..., Any]) -> bool:
signature = inspect.signature(func)
return_annotation = signature.return_annotation

if return_annotation is inspect.Signature.empty:
return False

if inspect.isclass(return_annotation):
return issubclass(return_annotation, StreamingResponse)

return False


def clean_path(path: str) -> str:
path = path.removeprefix("/").removesuffix("/")

zenoh_path = re.sub(PARAM_REGEX, "*", path)
zenoh_path = zenoh_path.replace("*/*", "**")

return zenoh_path


def get_parameters(sample: zenoh.Sample, func: Callable[..., Any]) -> dict[str, Any]:
source = None
if sample.kind == zenoh.SampleKind.PUT:
source = sample.payload
else:
source = sample.attachment

if source is None:
return {}

parameters = parameters_process(source.to_string())
return parameters_type_validation(parameters, func)


def parameters_process(parameters: str) -> dict[str, str]:
if not parameters:
return {}

result = {}
for parameter in parameters.split(";"):
parameter = parameter.strip()
if not parameter:
continue

if "=" not in parameter:
logger.warning(f"Skipping malformated parameter (no '=' found): {parameter}")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick (typo): Fix the typo in the warning message for malformed parameters.

Suggested change
logger.warning(f"Skipping malformated parameter (no '=' found): {parameter}")
logger.warning(f"Skipping malformed parameter (no '=' found): {parameter}")

continue

key, value = parameter.split("=", maxsplit=1)
key, value = key.strip(), value.strip()

if key:
result[key] = value
return result


def parameters_type_validation(parameters: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]:
signature = inspect.signature(func)
typed_parameters = {}

for key, value in parameters.items():
if key not in signature.parameters:
continue

annotation = signature.parameters[key].annotation
try:
Comment on lines +259 to +268
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Parameter type conversion logic is fragile for non-primitive annotations and may skip valid parameters.

In parameters_type_validation, any annotation outside (int, float, str, bool) is treated as a complex type, parsed with ast.literal_eval, and then instantiated via annotation(**parsed). This will fail for common cases such as:

  • list/dict/set (or list[str], etc.), which aren’t primitives but also don’t support annotation(**parsed), so you end up with warnings and dropped params.
  • Optional[int] / Union[int, None] and other typing constructs, which also fall into this branch and fail.

Consider special-casing container types (returning parsed when the structure already matches) and handling typing constructs more defensively, or falling back to the raw value when conversion is unclear, to avoid noisy warnings and unnecessary parameter drops.

Suggested implementation:

def parameters_type_validation(parameters: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]:
    signature = inspect.signature(func)
    typed_parameters: dict[str, Any] = {}

    # Local import to avoid introducing a hard dependency if typing.get_origin is not needed elsewhere
    try:
        from typing import get_origin, get_args  # type: ignore
    except ImportError:  # Python < 3.8 fallback – we simply won't special‑case typing constructs
        get_origin = lambda x: None  # type: ignore
        get_args = lambda x: ()      # type: ignore

    def _convert_value(raw_value: Any, annotation: Any) -> Any:
        """
        Best‑effort conversion of string/primitive values to the annotated type.

        - Primitives: cast directly (int, float, str, bool).
        - Built‑in containers: return the parsed structure (list/dict/set/tuple).
        - typing.Optional/Union: try the non‑None argument if it's primitive, otherwise
          leave as raw.
        - Other annotations: try to interpret as a dataclass/typed‑object taking **kwargs
          from a parsed dict; if that fails, fall back to the raw value.
        """
        if annotation is inspect._empty:
            return raw_value

        # Handle primitives
        if is_primitive(annotation):
            return annotation(raw_value)

        origin = get_origin(annotation)
        args = get_args(annotation)

        # Handle built‑in containers and collections from typing (list[str], dict[str, int], etc.)
        container_types = (list, dict, set, tuple)
        container_origin = origin or annotation

        if container_origin in container_types:
            parsed = raw_value
            if isinstance(raw_value, str):
                try:
                    parsed = ast.literal_eval(raw_value)
                except Exception:
                    # If we cannot parse, just return raw_value instead of dropping it
                    return raw_value

            # If the parsed value already matches the expected container type, just use it
            if isinstance(parsed, container_origin):
                return parsed

            # Last resort: try to coerce into the container without enforcing element types
            try:
                return container_origin(parsed)
            except Exception:
                return raw_value

        # Handle Optional[T] / Union[T, None] in a conservative way
        if origin is not None and origin is getattr(__import__("typing"), "Union", None):
            non_none_args = [a for a in args if a is not type(None)]  # noqa: E721
            if len(non_none_args) == 1 and is_primitive(non_none_args[0]):
                try:
                    return non_none_args[0](raw_value)
                except Exception:
                    return raw_value
            # If we cannot clearly determine the conversion, keep the raw value
            return raw_value

        # Fallback for custom / complex types: try to instantiate with **parsed
        if isinstance(raw_value, str):
            try:
                parsed = ast.literal_eval(raw_value)
            except Exception:
                return raw_value
        else:
            parsed = raw_value

        if isinstance(parsed, dict):
            try:
                return annotation(**parsed)
            except Exception:
                return raw_value

        # If nothing else worked, return the raw value
        return raw_value

    for key, value in parameters.items():
        if key not in signature.parameters:
            continue

        annotation = signature.parameters[key].annotation
        try:
            typed_parameters[key] = _convert_value(value, annotation)
        except Exception as e:
            logger.warning(f"Error converting parameter {key} to type {annotation}: {e}")
            # Preserve the original value instead of dropping the parameter
            typed_parameters[key] = value

    return typed_parameters

To fully support handling of typing constructs (e.g. Optional[int], list[str]) in a more idiomatic way, you may want to:

  1. Add a module‑level import near the top of zenoh_helper.py:
    • from typing import get_origin, get_args
      and then remove the local try/except ImportError block and directly use get_origin/get_args in parameters_type_validation.
  2. If your project already uses typing.get_origin/get_args elsewhere with a helper, you might instead want to reuse that helper to keep behavior consistent across the codebase.

if is_primitive(annotation):
typed_parameters[key] = annotation(value)
Comment on lines +269 to +270
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Boolean parameter parsing via direct bool(value) is misleading for typical string inputs.

For primitive annotations this makes bool("false") and bool("0") evaluate to True, which is likely surprising for string parameters. If booleans are provided as strings, special-case bool to parse known textual forms (e.g. "true"/"false", "1"/"0") and handle or reject anything else explicitly.

Suggested implementation:

        annotation = signature.parameters[key].annotation
        try:
            if is_primitive(annotation):
                if annotation is bool:
                    typed_parameters[key] = parse_bool(value)
                else:
                    typed_parameters[key] = annotation(value)
            else:
                parsed = ast.literal_eval(value)
                typed_parameters[key] = annotation(**parsed)
def is_primitive(value: Any) -> bool:
    return value in (int, float, str, bool)


def parse_bool(value: Any) -> bool:
    """
    Parse a value into a boolean, with explicit handling for common textual forms.

    Accepts:
    - bool: returned as-is
    - str: case-insensitive handling of "true"/"false", "1"/"0", "yes"/"no", "y"/"n", "t"/"f"
    - other types: falls back to bool(value)

    Raises:
        ValueError: if the string value does not match a known boolean representation.
    """
    if isinstance(value, bool):
        return value

    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"true", "t", "yes", "y", "1"}:
            return True
        if normalized in {"false", "f", "no", "n", "0"}:
            return False
        raise ValueError(f"Invalid boolean string value: {value!r}")

    return bool(value)

else:
parsed = ast.literal_eval(value)
typed_parameters[key] = annotation(**parsed)
except Exception as e:
logger.warning(f"Error converting parameter {key} to type {annotation}: {e}")
continue

return typed_parameters


def is_primitive(value: Any) -> bool:
return value in (int, float, str, bool)


def handle_json(data: Any) -> str:
try:
return json.dumps(data, default=str)
except (TypeError, ValueError) as e:
logger.error(f"Error serializing data to JSON: {data}, {e}")
return '{"error": "Serialization failed"}'