-
Notifications
You must be signed in to change notification settings - Fork 120
add zenoh support for POST, PUT and DELETE requests #3728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,72 @@ | ||||||||||||||||
| from typing import Any, Callable, Dict, List | ||||||||||||||||
|
|
||||||||||||||||
| import zenoh | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| class TreeNode: | ||||||||||||||||
| def __init__(self, segment: str): | ||||||||||||||||
| self.segment = segment | ||||||||||||||||
| self.children: Dict[str, "TreeNode"] = {} | ||||||||||||||||
| self.is_valid = False | ||||||||||||||||
| self.methods: Dict[str, Callable[..., Any]] = {} | ||||||||||||||||
|
|
||||||||||||||||
| def add_child(self, child: "TreeNode") -> "TreeNode": | ||||||||||||||||
| if child.segment in self.children: | ||||||||||||||||
| return self.children[child.segment] | ||||||||||||||||
| self.children[child.segment] = child | ||||||||||||||||
| return child | ||||||||||||||||
|
Comment on lines
+14
to
+17
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| def get_methods(self) -> Dict[str, Callable[..., Any]]: | ||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would recommend to use _ prefix for private properties and avoid using get prefix in functions. https://github.com/bluerobotics/software-guidelines/blob/master/guidelines/style.md |
||||||||||||||||
| return self.methods | ||||||||||||||||
|
|
||||||||||||||||
| def add_node(self, segments: List[str], method: str, func: Callable[..., Any]) -> None: | ||||||||||||||||
| if self.segment != segments[0]: | ||||||||||||||||
| return | ||||||||||||||||
|
|
||||||||||||||||
| if len(segments) == 1: | ||||||||||||||||
| self.is_valid = True | ||||||||||||||||
| self.methods[method] = func | ||||||||||||||||
| return | ||||||||||||||||
|
|
||||||||||||||||
| child = self.add_child(TreeNode(segments[1])) | ||||||||||||||||
| child.add_node(segments[1:], method, func) | ||||||||||||||||
|
|
||||||||||||||||
| def process_path(self, path: str, method: str, func: Callable[..., Any]) -> None: | ||||||||||||||||
| segments = path.split("/") | ||||||||||||||||
| self.add_node(segments, method, func) | ||||||||||||||||
|
|
||||||||||||||||
| def get_match(self, path: str) -> tuple[str, "TreeNode"] | None: | ||||||||||||||||
| segments = path.split("/") | ||||||||||||||||
|
|
||||||||||||||||
| if self.segment != segments[0]: | ||||||||||||||||
| return None | ||||||||||||||||
|
|
||||||||||||||||
| node = self | ||||||||||||||||
| matched_path = node.segment | ||||||||||||||||
|
|
||||||||||||||||
| for segment in segments[1:]: | ||||||||||||||||
| try: | ||||||||||||||||
| node = node.children[segment] | ||||||||||||||||
| matched_path += "/" + node.segment | ||||||||||||||||
| except KeyError: | ||||||||||||||||
| try: | ||||||||||||||||
| node = node.children["*"] | ||||||||||||||||
| matched_path += "/" + node.segment | ||||||||||||||||
| except KeyError: | ||||||||||||||||
| return None | ||||||||||||||||
|
|
||||||||||||||||
| if node.is_valid: | ||||||||||||||||
| return matched_path, node | ||||||||||||||||
| return None | ||||||||||||||||
|
|
||||||||||||||||
| def get_corresponding_method(self, sampleKind: zenoh.SampleKind) -> Callable[..., Any] | None: | ||||||||||||||||
| if sampleKind == zenoh.SampleKind.DELETE: | ||||||||||||||||
| keyword = "DELETE" | ||||||||||||||||
| else: | ||||||||||||||||
| methods = [method for method, _ in self.methods.items() if method != "DELETE"] | ||||||||||||||||
| keyword = methods[0] | ||||||||||||||||
|
Comment on lines
+62
to
+67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Guard against empty or inconsistent method mappings when resolving the handler, to avoid IndexError and surprising routing. In
Comment on lines
+63
to
+67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (bug_risk): Clarify and tighten the mapping between SampleKind and HTTP methods to avoid arbitrary handler selection. Here you special-case Suggested implementation: if node.is_valid:
return matched_path, node
return None
def get_corresponding_method(self, sampleKind: zenoh.SampleKind) -> Callable[..., Any] | None:
"""
Map a zenoh.SampleKind to an HTTP method in a deterministic way.
1. Try an explicit SampleKind -> HTTP verb mapping.
2. If the mapped verb is not supported by this node, fall back to a
deterministic preference order among supported non-DELETE methods.
3. As a final fallback, choose the first non-DELETE method in sorted
order to keep behavior deterministic.
"""
# Explicit mapping from SampleKind to preferred HTTP method
samplekind_to_method: dict[zenoh.SampleKind, str] = {
zenoh.SampleKind.DELETE: "DELETE",
getattr(zenoh.SampleKind, "PUT", None): "PUT",
getattr(zenoh.SampleKind, "PATCH", None): "PATCH",
}
# Remove any None keys in case PUT/PATCH do not exist in this zenoh version
samplekind_to_method = {
k: v for k, v in samplekind_to_method.items() if k is not None
}
keyword: str | None = samplekind_to_method.get(sampleKind)
# If there is no direct mapping or the mapped method is not supported,
# fall back to a deterministic preference order (excluding DELETE).
if keyword is None or keyword not in self.methods:
preference_order: tuple[str, ...] = ("PUT", "POST", "PATCH", "GET")
keyword = None
for method in preference_order:
if method in self.methods and method != "DELETE":
keyword = method
break
# As a last resort, pick the first non-DELETE method in sorted order
if keyword is None:
non_delete_methods = sorted(
m for m in self.methods.keys() if m != "DELETE"
)
if non_delete_methods:
keyword = non_delete_methods[0]
if keyword is None:
return None
return self.methods.get(keyword)If from typing import Any, Callablenear the other imports. Also ensure that |
||||||||||||||||
|
|
||||||||||||||||
| try: | ||||||||||||||||
| return self.methods[keyword] | ||||||||||||||||
| except KeyError: | ||||||||||||||||
| return None | ||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,17 +1,35 @@ | ||||||
| import ast | ||||||
| import asyncio | ||||||
| import inspect | ||||||
| import json | ||||||
| import re | ||||||
| from concurrent.futures import ThreadPoolExecutor | ||||||
| from functools import wraps | ||||||
| from typing import Any, Callable | ||||||
|
|
||||||
| import fastapi | ||||||
| import zenoh | ||||||
| from commonwealth.utils.tree import TreeNode | ||||||
| from fastapi.routing import APIRoute | ||||||
| from loguru import logger | ||||||
| from starlette.responses import StreamingResponse | ||||||
|
|
||||||
| from .Singleton import Singleton | ||||||
|
|
||||||
| PARAM_REGEX = r"{[a-zA-Z0-9_]+}" | ||||||
| PARAM_REGEX = r"{[a-zA-Z0-9_:]+}" | ||||||
| RESPONSE_PREFIX = "response/" | ||||||
|
|
||||||
|
|
||||||
| def _async_to_sync(async_func: Callable[..., Any]) -> Callable[..., Any]: | ||||||
| """ | ||||||
| Decorator to convert an async function to a sync function. | ||||||
| """ | ||||||
|
|
||||||
| @wraps(async_func) | ||||||
| def wrapper() -> None: | ||||||
| asyncio.run(async_func()) | ||||||
|
|
||||||
| return wrapper | ||||||
|
|
||||||
|
|
||||||
| class ZenohSession(metaclass=Singleton): | ||||||
|
|
@@ -66,24 +84,23 @@ def zenoh_config(self, service_name: str) -> None: | |||||
| class ZenohRouter: | ||||||
| prefix: str | ||||||
| zenoh_session: ZenohSession | ||||||
| tree: TreeNode | ||||||
|
|
||||||
| def __init__(self, service_name: str): | ||||||
| self.prefix = service_name | ||||||
| self.zenoh_session = ZenohSession(service_name) | ||||||
| self.tree = TreeNode(service_name) | ||||||
|
|
||||||
| def add_queryable(self, path: str, func: Callable[..., Any]) -> None: | ||||||
| full_path = self.prefix | ||||||
| if path: | ||||||
| full_path += f"/{path}" | ||||||
|
|
||||||
| def wrapper(query: zenoh.Query) -> None: | ||||||
| params = dict(query.parameters) # type: ignore | ||||||
|
|
||||||
| @_async_to_sync | ||||||
| async def _handle_async() -> None: | ||||||
| try: | ||||||
| response = await func(**params) | ||||||
| if response is not None: | ||||||
| query.reply(query.selector.key_expr, json.dumps(response, default=str)) | ||||||
| query.reply(query.selector.key_expr, handle_json(response)) | ||||||
| except Exception as e: | ||||||
| logger.exception(f"Error in zenoh query handler: {query.selector.key_expr}") | ||||||
| error_response = { | ||||||
|
|
@@ -92,34 +109,182 @@ async def _handle_async() -> None: | |||||
| } | ||||||
| query.reply(query.selector.key_expr, json.dumps(error_response)) | ||||||
|
|
||||||
| def run_async() -> None: | ||||||
| asyncio.run(_handle_async()) | ||||||
| self.zenoh_session.submit_to_executor(_handle_async) | ||||||
|
|
||||||
| if self.zenoh_session.session: | ||||||
| self.zenoh_session.session.declare_queryable(path, wrapper) | ||||||
|
|
||||||
| self.zenoh_session.submit_to_executor(run_async) | ||||||
| def add_subscriber(self, path: str, func: Callable[..., Any]) -> None: | ||||||
| def wrapper(sample: zenoh.Sample) -> None: | ||||||
| if not self._should_process(sample, path, func): | ||||||
| return | ||||||
|
|
||||||
| @_async_to_sync | ||||||
| async def _handle_async() -> None: | ||||||
| try: | ||||||
| parameters = get_parameters(sample, func) | ||||||
| result = await func(**parameters) | ||||||
| if result is not None: | ||||||
| await self.process_subscriber_response(result, sample) | ||||||
| except Exception as e: | ||||||
| logger.exception(f"Error in zenoh subscriber handler on {path}: {sample.kind=}, {str(e)}") | ||||||
|
|
||||||
| self.zenoh_session.submit_to_executor(_handle_async) | ||||||
|
|
||||||
| if self.zenoh_session.session: | ||||||
| self.zenoh_session.session.declare_queryable(full_path, wrapper) | ||||||
| self.zenoh_session.session.declare_subscriber(path, wrapper) | ||||||
|
|
||||||
| async def process_subscriber_response(self, result: Any, sample: zenoh.Sample) -> None: | ||||||
| if self.zenoh_session.session is None: | ||||||
| return | ||||||
|
|
||||||
| response_key = RESPONSE_PREFIX + str(sample.key_expr) | ||||||
| logger.info(f"The response for {sample.key_expr} will be published to {response_key}.") | ||||||
|
|
||||||
| with self.zenoh_session.session.declare_publisher(response_key) as publisher: | ||||||
| if isinstance(result, StreamingResponse): | ||||||
| await self._handle_streaming_response(result, publisher) | ||||||
| else: | ||||||
| publisher.put(handle_json(result)) | ||||||
|
|
||||||
| async def _handle_streaming_response(self, result: StreamingResponse, publisher: zenoh.Publisher) -> None: | ||||||
| async for chunk in result.body_iterator: | ||||||
| if isinstance(chunk, (dict, list)): | ||||||
| chunk_data = json.dumps(chunk, default=str) | ||||||
| elif isinstance(chunk, bytes): | ||||||
| chunk_data = chunk.decode("utf-8", errors="replace") | ||||||
| else: | ||||||
| chunk_data = str(chunk) | ||||||
| publisher.put(chunk_data) | ||||||
|
|
||||||
| def add_routes_to_zenoh(self, app: fastapi.FastAPI) -> None: | ||||||
| queryables = [] | ||||||
| methods = self._get_methods(app) | ||||||
| for method, path, func in methods: | ||||||
| full_path = self.get_route_path(path) | ||||||
| self.tree.process_path(full_path, method, func) | ||||||
|
|
||||||
| if method == "GET" and not is_streaming_response(func): | ||||||
| self.add_queryable(full_path, func) | ||||||
| else: | ||||||
| self.add_subscriber(full_path, func) | ||||||
|
|
||||||
| def _get_methods(self, app: fastapi.FastAPI) -> list[tuple[str, str, Callable[..., Any]]]: | ||||||
| methods = [] | ||||||
| for route in app.router.routes: | ||||||
| route_type = type(route) | ||||||
| if ( | ||||||
| isinstance(route, APIRoute) | ||||||
| and route_type.__name__ == "VersionedAPIRoute" | ||||||
| and "fastapi_versioning" in route_type.__module__ | ||||||
| and "GET" in route.methods | ||||||
| and type(route).__name__ == "VersionedAPIRoute" | ||||||
| and "fastapi_versioning" in type(route).__module__ | ||||||
| ): | ||||||
| queryables.append((clean_path(route.path), route.endpoint)) | ||||||
| method_type = next(iter(route.methods), None) | ||||||
| if method_type and method_type in ("GET", "POST", "PUT", "DELETE"): | ||||||
| methods.append((method_type, clean_path(route.path), route.endpoint)) | ||||||
| return methods | ||||||
|
|
||||||
| def get_route_path(self, path: str) -> str: | ||||||
| return self.prefix + "/" + path if path else self.prefix | ||||||
|
|
||||||
| def _should_process(self, sample: zenoh.Sample, endpoint: str, func: Callable[..., Any]) -> bool: | ||||||
| matched = self.tree.get_match(str(sample.key_expr)) | ||||||
|
|
||||||
| if matched is None: | ||||||
| return False | ||||||
|
|
||||||
| matched_path, node = matched | ||||||
| if matched_path != endpoint: | ||||||
| return False | ||||||
|
|
||||||
| found_method = node.get_corresponding_method(sample.kind) | ||||||
| if found_method is None: | ||||||
| return False | ||||||
| return found_method == func | ||||||
|
|
||||||
|
|
||||||
| for path, func in queryables: | ||||||
| self.add_queryable(path, func) | ||||||
| def is_streaming_response(func: Callable[..., Any]) -> bool: | ||||||
| signature = inspect.signature(func) | ||||||
| return_annotation = signature.return_annotation | ||||||
|
|
||||||
| if return_annotation is inspect.Signature.empty: | ||||||
| return False | ||||||
|
|
||||||
| if inspect.isclass(return_annotation): | ||||||
| return issubclass(return_annotation, StreamingResponse) | ||||||
|
|
||||||
| return False | ||||||
|
|
||||||
|
|
||||||
| def clean_path(path: str) -> str: | ||||||
| path = path.removeprefix("/").removesuffix("/") | ||||||
|
|
||||||
| zenoh_path = re.sub(PARAM_REGEX, "*", path) | ||||||
| zenoh_path = zenoh_path.replace("*/*", "**") | ||||||
|
|
||||||
| return zenoh_path | ||||||
|
|
||||||
|
|
||||||
| def get_parameters(sample: zenoh.Sample, func: Callable[..., Any]) -> dict[str, Any]: | ||||||
| source = None | ||||||
| if sample.kind == zenoh.SampleKind.PUT: | ||||||
| source = sample.payload | ||||||
| else: | ||||||
| source = sample.attachment | ||||||
|
|
||||||
| if source is None: | ||||||
| return {} | ||||||
|
|
||||||
| parameters = parameters_process(source.to_string()) | ||||||
| return parameters_type_validation(parameters, func) | ||||||
|
|
||||||
|
|
||||||
| def parameters_process(parameters: str) -> dict[str, str]: | ||||||
| if not parameters: | ||||||
| return {} | ||||||
|
|
||||||
| result = {} | ||||||
| for parameter in parameters.split(";"): | ||||||
| parameter = parameter.strip() | ||||||
| if not parameter: | ||||||
| continue | ||||||
|
|
||||||
| if "=" not in parameter: | ||||||
| logger.warning(f"Skipping malformated parameter (no '=' found): {parameter}") | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick (typo): Fix the typo in the warning message for malformed parameters.
Suggested change
|
||||||
| continue | ||||||
|
|
||||||
| key, value = parameter.split("=", maxsplit=1) | ||||||
| key, value = key.strip(), value.strip() | ||||||
|
|
||||||
| if key: | ||||||
| result[key] = value | ||||||
| return result | ||||||
|
|
||||||
|
|
||||||
| def parameters_type_validation(parameters: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]: | ||||||
| signature = inspect.signature(func) | ||||||
| typed_parameters = {} | ||||||
|
|
||||||
| for key, value in parameters.items(): | ||||||
| if key not in signature.parameters: | ||||||
| continue | ||||||
|
|
||||||
| annotation = signature.parameters[key].annotation | ||||||
| try: | ||||||
|
Comment on lines
+259
to
+268
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (bug_risk): Parameter type conversion logic is fragile for non-primitive annotations and may skip valid parameters. In
Consider special-casing container types (returning Suggested implementation: def parameters_type_validation(parameters: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]:
signature = inspect.signature(func)
typed_parameters: dict[str, Any] = {}
# Local import to avoid introducing a hard dependency if typing.get_origin is not needed elsewhere
try:
from typing import get_origin, get_args # type: ignore
except ImportError: # Python < 3.8 fallback – we simply won't special‑case typing constructs
get_origin = lambda x: None # type: ignore
get_args = lambda x: () # type: ignore
def _convert_value(raw_value: Any, annotation: Any) -> Any:
"""
Best‑effort conversion of string/primitive values to the annotated type.
- Primitives: cast directly (int, float, str, bool).
- Built‑in containers: return the parsed structure (list/dict/set/tuple).
- typing.Optional/Union: try the non‑None argument if it's primitive, otherwise
leave as raw.
- Other annotations: try to interpret as a dataclass/typed‑object taking **kwargs
from a parsed dict; if that fails, fall back to the raw value.
"""
if annotation is inspect._empty:
return raw_value
# Handle primitives
if is_primitive(annotation):
return annotation(raw_value)
origin = get_origin(annotation)
args = get_args(annotation)
# Handle built‑in containers and collections from typing (list[str], dict[str, int], etc.)
container_types = (list, dict, set, tuple)
container_origin = origin or annotation
if container_origin in container_types:
parsed = raw_value
if isinstance(raw_value, str):
try:
parsed = ast.literal_eval(raw_value)
except Exception:
# If we cannot parse, just return raw_value instead of dropping it
return raw_value
# If the parsed value already matches the expected container type, just use it
if isinstance(parsed, container_origin):
return parsed
# Last resort: try to coerce into the container without enforcing element types
try:
return container_origin(parsed)
except Exception:
return raw_value
# Handle Optional[T] / Union[T, None] in a conservative way
if origin is not None and origin is getattr(__import__("typing"), "Union", None):
non_none_args = [a for a in args if a is not type(None)] # noqa: E721
if len(non_none_args) == 1 and is_primitive(non_none_args[0]):
try:
return non_none_args[0](raw_value)
except Exception:
return raw_value
# If we cannot clearly determine the conversion, keep the raw value
return raw_value
# Fallback for custom / complex types: try to instantiate with **parsed
if isinstance(raw_value, str):
try:
parsed = ast.literal_eval(raw_value)
except Exception:
return raw_value
else:
parsed = raw_value
if isinstance(parsed, dict):
try:
return annotation(**parsed)
except Exception:
return raw_value
# If nothing else worked, return the raw value
return raw_value
for key, value in parameters.items():
if key not in signature.parameters:
continue
annotation = signature.parameters[key].annotation
try:
typed_parameters[key] = _convert_value(value, annotation)
except Exception as e:
logger.warning(f"Error converting parameter {key} to type {annotation}: {e}")
# Preserve the original value instead of dropping the parameter
typed_parameters[key] = value
return typed_parametersTo fully support handling of typing constructs (e.g.
|
||||||
| if is_primitive(annotation): | ||||||
| typed_parameters[key] = annotation(value) | ||||||
|
Comment on lines
+269
to
+270
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (bug_risk): Boolean parameter parsing via direct For primitive annotations this makes Suggested implementation: annotation = signature.parameters[key].annotation
try:
if is_primitive(annotation):
if annotation is bool:
typed_parameters[key] = parse_bool(value)
else:
typed_parameters[key] = annotation(value)
else:
parsed = ast.literal_eval(value)
typed_parameters[key] = annotation(**parsed)def is_primitive(value: Any) -> bool:
return value in (int, float, str, bool)
def parse_bool(value: Any) -> bool:
"""
Parse a value into a boolean, with explicit handling for common textual forms.
Accepts:
- bool: returned as-is
- str: case-insensitive handling of "true"/"false", "1"/"0", "yes"/"no", "y"/"n", "t"/"f"
- other types: falls back to bool(value)
Raises:
ValueError: if the string value does not match a known boolean representation.
"""
if isinstance(value, bool):
return value
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"true", "t", "yes", "y", "1"}:
return True
if normalized in {"false", "f", "no", "n", "0"}:
return False
raise ValueError(f"Invalid boolean string value: {value!r}")
return bool(value) |
||||||
| else: | ||||||
| parsed = ast.literal_eval(value) | ||||||
| typed_parameters[key] = annotation(**parsed) | ||||||
| except Exception as e: | ||||||
| logger.warning(f"Error converting parameter {key} to type {annotation}: {e}") | ||||||
| continue | ||||||
|
|
||||||
| return typed_parameters | ||||||
|
|
||||||
|
|
||||||
| def is_primitive(value: Any) -> bool: | ||||||
| return value in (int, float, str, bool) | ||||||
|
|
||||||
|
|
||||||
| def handle_json(data: Any) -> str: | ||||||
| try: | ||||||
| return json.dumps(data, default=str) | ||||||
| except (TypeError, ValueError) as e: | ||||||
| logger.error(f"Error serializing data to JSON: {data}, {e}") | ||||||
| return '{"error": "Serialization failed"}' | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add a description of how this class operates ? It's pretty complex just looking at the code.