diff --git a/src/callosum/rpc/channel.py b/src/callosum/rpc/channel.py index e5b250b..0f0841b 100644 --- a/src/callosum/rpc/channel.py +++ b/src/callosum/rpc/channel.py @@ -17,8 +17,6 @@ Union, ) -import attrs - from ..abc import ( AbstractChannel, AbstractDeserializer, @@ -370,9 +368,9 @@ async def invoke( server_cancelled = True raise asyncio.CancelledError elif response.msgtype == RPCMessageTypes.FAILURE: - raise RPCUserError(*attrs.astuple(response.metadata)) + raise RPCUserError.from_err_metadata(response.metadata) elif response.msgtype == RPCMessageTypes.ERROR: - raise RPCInternalError(*attrs.astuple(response.metadata)) + raise RPCInternalError.from_err_metadata(response.metadata) return upper_result except (asyncio.TimeoutError, asyncio.CancelledError): # propagate cancellation to the connected peer diff --git a/src/callosum/rpc/exceptions.py b/src/callosum/rpc/exceptions.py index b8609f2..8635f8a 100644 --- a/src/callosum/rpc/exceptions.py +++ b/src/callosum/rpc/exceptions.py @@ -1,41 +1,47 @@ -from ..exceptions import CallosumError +from __future__ import annotations +from typing import TYPE_CHECKING, Self -class RPCError(CallosumError): - """ - A base exception for all RPC-specific errors. - """ +from ..exceptions import CallosumError - pass +if TYPE_CHECKING: + from .message import ErrorMetadata -class RPCUserError(RPCError): +class RPCError(CallosumError): """ - Represents an error caused in user-defined handlers. + A base exception for all RPC-specific errors. """ name: str repr: str traceback: str + exceptions: tuple - def __init__(self, name: str, repr_: str, tb: str, *args): + def __init__(self, name: str, repr_: str, tb: str, exceptions: tuple, *args): super().__init__(name, repr_, tb, *args) self.name = name self.repr = repr_ self.traceback = tb + self.exceptions = exceptions + @classmethod + def from_err_metadata(cls, metadata: ErrorMetadata) -> Self: + return cls( + metadata.name, + metadata.repr, + metadata.traceback, + tuple(cls.from_err_metadata(err) for err in metadata.sub_errors), + ) -class RPCInternalError(RPCError): + +class RPCUserError(RPCError): """ - Represents an error caused in Calloum's internal logic. + Represents an error caused in user-defined handlers. """ - name: str - repr: str - traceback: str - def __init__(self, name: str, repr_: str, tb: str, *args): - super().__init__(name, tb, *args) - self.name = name - self.repr = repr_ - self.traceback = tb +class RPCInternalError(RPCError): + """ + Represents an error caused in Calloum's internal logic. + """ diff --git a/src/callosum/rpc/message.py b/src/callosum/rpc/message.py index 8fd8913..d35c330 100644 --- a/src/callosum/rpc/message.py +++ b/src/callosum/rpc/message.py @@ -72,6 +72,55 @@ class ErrorMetadata(Metadata): repr: str traceback: str + sub_errors: tuple[ErrorMetadata, ...] = attrs.field(factory=tuple) + + @classmethod + def decode(cls, buffer: bytes) -> Any: + if not buffer: + return None + values = munpackb(buffer) + match values: + case (name, repr, traceback, raw_sub_errors): + return cls( + name, + repr, + traceback, + tuple(cls.decode(raw_error) for raw_error in raw_sub_errors), + ) + case _: + return cls(*values) + + def encode(self) -> bytes: + values = [ + self.name, + self.repr, + self.traceback, + [err.encode() for err in self.sub_errors], + ] + return mpackb(values) + + @classmethod + def from_exception( + cls, exc: BaseExceptionGroup | BaseException, formatted_traceback: str + ) -> ErrorMetadata: + match exc: + case BaseExceptionGroup(): + return ErrorMetadata( + "ExceptionGroup", + repr(exc), + formatted_traceback, + sub_errors=tuple( + cls.from_exception(sub_exc, formatted_traceback) + for sub_exc in exc.exceptions + ), + ) + case _: + return ErrorMetadata( + type(exc).__name__, + repr(exc), + formatted_traceback, + ) + @attrs.define(frozen=True, slots=True) class NullMetadata(Metadata): @@ -150,11 +199,7 @@ def failure(cls, request): request.method, request.order_key, request.client_seq_id, - ErrorMetadata( - exc_info[0].__name__, - repr(exc_info[1]), - traceback.format_exc(), - ), + ErrorMetadata.from_exception(exc_info[1], traceback.format_exc()), None, ) @@ -174,11 +219,7 @@ def error(cls, request): request.method, request.order_key, request.client_seq_id, - ErrorMetadata( - exc_info[0].__name__, - repr(exc_info[1]), - traceback.format_exc(), - ), + ErrorMetadata.from_exception(exc_info[1], traceback.format_exc()), None, )