diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index cee74375..22b07435 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -19,7 +19,7 @@ import os import mimetypes import aiofiles -from typing import List, Union, Callable, Dict, Awaitable, Optional, Mapping, cast +from typing import List, Union, Callable, Dict, Awaitable, Optional, Mapping, cast, TypeVar from abc import abstractmethod, ABC from ._ffi_client import FfiClient, FfiHandle @@ -144,6 +144,12 @@ def disconnect_reason( return self._info.disconnect_reason +RpcHandler = Callable[["RpcInvocationData"], Union[Awaitable[Optional[str]], Optional[str]]] +F = TypeVar( + "F", bound=Callable[[RpcInvocationData], Union[Awaitable[Optional[str]], Optional[str]]] +) + + class LocalParticipant(Participant): """Represents the local participant in a room.""" @@ -155,9 +161,7 @@ def __init__( super().__init__(owned_info) self._room_queue = room_queue self._track_publications: dict[str, LocalTrackPublication] = {} # type: ignore - self._rpc_handlers: Dict[ - str, Callable[[RpcInvocationData], Union[Awaitable[str], str]] - ] = {} + self._rpc_handlers: Dict[str, RpcHandler] = {} @property def track_publications(self) -> Mapping[str, LocalTrackPublication]: @@ -328,8 +332,8 @@ async def perform_rpc( def register_rpc_method( self, method_name: str, - handler: Optional[Callable[[RpcInvocationData], Union[Awaitable[str], str]]] = None, - ) -> Union[None, Callable]: + handler: Optional[F] = None, + ) -> Union[F, Callable[[F], F]]: """ Establishes the participant as a receiver for calls of the specified RPC method. Can be used either as a decorator or a regular method. @@ -366,18 +370,17 @@ async def greet_handler(data: RpcInvocationData) -> str: room.local_participant.register_rpc_method('greet', greet_handler) """ - def register(handler_func): + def register(handler_func: F) -> F: self._rpc_handlers[method_name] = handler_func req = proto_ffi.FfiRequest() req.register_rpc_method.local_participant_handle = self._ffi_handle.handle req.register_rpc_method.method = method_name FfiClient.instance.request(req) + return handler_func if handler is not None: - register(handler) - return None + return register(handler) else: - # Called as a decorator return register def unregister_rpc_method(self, method: str) -> None: @@ -438,33 +441,22 @@ async def _handle_rpc_method_invocation( else: try: if asyncio.iscoroutinefunction(handler): - async_handler = cast(Callable[[RpcInvocationData], Awaitable[str]], handler) - - async def run_handler(): - try: - return await async_handler(params) - except asyncio.CancelledError: - # This will be caught by the outer try-except if it's due to timeout - raise - try: response_payload = await asyncio.wait_for( - run_handler(), timeout=response_timeout + handler(params), timeout=response_timeout ) except asyncio.TimeoutError: raise RpcError._built_in(RpcError.ErrorCode.RESPONSE_TIMEOUT) except asyncio.CancelledError: raise RpcError._built_in(RpcError.ErrorCode.RECIPIENT_DISCONNECTED) else: - sync_handler = cast(Callable[[RpcInvocationData], str], handler) - response_payload = sync_handler(params) + response_payload = cast(Optional[str], handler(params)) except RpcError as error: response_error = error - except Exception as error: + except Exception: logger.exception( f"Uncaught error returned by RPC handler for {method}. " "Returning APPLICATION_ERROR instead. " - f"Original error: {error}" ) response_error = RpcError._built_in(RpcError.ErrorCode.APPLICATION_ERROR) @@ -480,8 +472,8 @@ async def run_handler(): res = FfiClient.instance.request(req) if res.rpc_method_invocation_response.error: - message = res.rpc_method_invocation_response.error - logger.exception(f"error sending rpc method invocation response: {message}") + err = res.rpc_method_invocation_response.error + logger.error(f"error sending rpc method invocation response: {err}") async def set_metadata(self, metadata: str) -> None: """