From a31229eac2740bb0e1145ad412d64cd0ff33c4f3 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Fri, 19 Jul 2024 18:04:20 +0900 Subject: [PATCH 1/2] feat: Support to serialize ExceptionGroup --- src/callosum/rpc/channel.py | 4 +-- src/callosum/rpc/exceptions.py | 43 +++++++++++++++++-------------- src/callosum/rpc/message.py | 46 ++++++++++++++++++++++++++-------- 3 files changed, 62 insertions(+), 31 deletions(-) diff --git a/src/callosum/rpc/channel.py b/src/callosum/rpc/channel.py index e5b250b..eedc11c 100644 --- a/src/callosum/rpc/channel.py +++ b/src/callosum/rpc/channel.py @@ -370,9 +370,9 @@ async def invoke( server_cancelled = True raise asyncio.CancelledError elif response.msgtype == RPCMessageTypes.FAILURE: - raise RPCUserError(*attrs.astuple(response.metadata)) + raise RPCUserError.from_err_metadata(response.metadata) elif response.msgtype == RPCMessageTypes.ERROR: - raise RPCInternalError(*attrs.astuple(response.metadata)) + raise RPCInternalError.from_err_metadata(response.metadata) return upper_result except (asyncio.TimeoutError, asyncio.CancelledError): # propagate cancellation to the connected peer diff --git a/src/callosum/rpc/exceptions.py b/src/callosum/rpc/exceptions.py index b8609f2..16aeb24 100644 --- a/src/callosum/rpc/exceptions.py +++ b/src/callosum/rpc/exceptions.py @@ -1,41 +1,46 @@ -from ..exceptions import CallosumError +from __future__ import annotations +from typing import TYPE_CHECKING, Self -class RPCError(CallosumError): - """ - A base exception for all RPC-specific errors. - """ +from ..exceptions import CallosumError - pass +if TYPE_CHECKING: + from .message import ErrorMetadata -class RPCUserError(RPCError): +class RPCError(CallosumError): """ - Represents an error caused in user-defined handlers. + A base exception for all RPC-specific errors. """ name: str repr: str traceback: str + exceptions: tuple - def __init__(self, name: str, repr_: str, tb: str, *args): + def __init__(self, name: str, repr_: str, tb: str, exceptions: tuple, *args): super().__init__(name, repr_, tb, *args) self.name = name self.repr = repr_ self.traceback = tb + self.exceptions = exceptions + + @classmethod + def from_err_metadata(cls, metadata: ErrorMetadata) -> Self: + return cls( + metadata.name, + metadata.repr, + metadata.traceback, + tuple(cls.from_err_metadata(err) for err in metadata.sub_errors), + ) + +class RPCUserError(RPCError): + """ + Represents an error caused in user-defined handlers. + """ class RPCInternalError(RPCError): """ Represents an error caused in Calloum's internal logic. """ - - name: str - repr: str - traceback: str - - def __init__(self, name: str, repr_: str, tb: str, *args): - super().__init__(name, tb, *args) - self.name = name - self.repr = repr_ - self.traceback = tb diff --git a/src/callosum/rpc/message.py b/src/callosum/rpc/message.py index 8fd8913..e5d2e71 100644 --- a/src/callosum/rpc/message.py +++ b/src/callosum/rpc/message.py @@ -72,6 +72,40 @@ class ErrorMetadata(Metadata): repr: str traceback: str + sub_errors: tuple[ErrorMetadata, ...] = attrs.field(factory=tuple) + + @classmethod + def decode(cls, buffer: bytes) -> Any: + if not buffer: + return None + values = munpackb(buffer) + match values: + case (name, repr, traceback, raw_sub_errors): + return cls(name, repr, traceback, tuple(cls.decode(raw_error) for raw_error in raw_sub_errors)) + case _: + return cls(*values) + + def encode(self) -> bytes: + values = [self.name, self.repr, self.traceback, [err.encode() for err in self.sub_errors]] + return mpackb(values) + + @classmethod + def from_exception(cls, exc: BaseExceptionGroup | BaseException, formatted_traceback: str) -> ErrorMetadata: + match exc: + case BaseExceptionGroup(): + return ErrorMetadata( + "ExceptionGroup", + repr(exc), + formatted_traceback, + sub_errors=tuple(cls.from_exception(sub_exc, formatted_traceback) for sub_exc in exc.exceptions) + ) + case _: + return ErrorMetadata( + type(exc).__name__, + repr(exc), + formatted_traceback, + ) + @attrs.define(frozen=True, slots=True) class NullMetadata(Metadata): @@ -150,11 +184,7 @@ def failure(cls, request): request.method, request.order_key, request.client_seq_id, - ErrorMetadata( - exc_info[0].__name__, - repr(exc_info[1]), - traceback.format_exc(), - ), + ErrorMetadata.from_exception(exc_info[1], traceback.format_exc()), None, ) @@ -174,11 +204,7 @@ def error(cls, request): request.method, request.order_key, request.client_seq_id, - ErrorMetadata( - exc_info[0].__name__, - repr(exc_info[1]), - traceback.format_exc(), - ), + ErrorMetadata.from_exception(exc_info[1], traceback.format_exc()), None, ) From 2657f4106feb30754c5af4e7938ae75a78316211 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jul 2024 09:11:08 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/callosum/rpc/channel.py | 2 -- src/callosum/rpc/exceptions.py | 1 + src/callosum/rpc/message.py | 23 +++++++++++++++++++---- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/callosum/rpc/channel.py b/src/callosum/rpc/channel.py index eedc11c..0f0841b 100644 --- a/src/callosum/rpc/channel.py +++ b/src/callosum/rpc/channel.py @@ -17,8 +17,6 @@ Union, ) -import attrs - from ..abc import ( AbstractChannel, AbstractDeserializer, diff --git a/src/callosum/rpc/exceptions.py b/src/callosum/rpc/exceptions.py index 16aeb24..8635f8a 100644 --- a/src/callosum/rpc/exceptions.py +++ b/src/callosum/rpc/exceptions.py @@ -40,6 +40,7 @@ class RPCUserError(RPCError): Represents an error caused in user-defined handlers. """ + class RPCInternalError(RPCError): """ Represents an error caused in Calloum's internal logic. diff --git a/src/callosum/rpc/message.py b/src/callosum/rpc/message.py index e5d2e71..d35c330 100644 --- a/src/callosum/rpc/message.py +++ b/src/callosum/rpc/message.py @@ -81,23 +81,38 @@ def decode(cls, buffer: bytes) -> Any: values = munpackb(buffer) match values: case (name, repr, traceback, raw_sub_errors): - return cls(name, repr, traceback, tuple(cls.decode(raw_error) for raw_error in raw_sub_errors)) + return cls( + name, + repr, + traceback, + tuple(cls.decode(raw_error) for raw_error in raw_sub_errors), + ) case _: return cls(*values) def encode(self) -> bytes: - values = [self.name, self.repr, self.traceback, [err.encode() for err in self.sub_errors]] + values = [ + self.name, + self.repr, + self.traceback, + [err.encode() for err in self.sub_errors], + ] return mpackb(values) @classmethod - def from_exception(cls, exc: BaseExceptionGroup | BaseException, formatted_traceback: str) -> ErrorMetadata: + def from_exception( + cls, exc: BaseExceptionGroup | BaseException, formatted_traceback: str + ) -> ErrorMetadata: match exc: case BaseExceptionGroup(): return ErrorMetadata( "ExceptionGroup", repr(exc), formatted_traceback, - sub_errors=tuple(cls.from_exception(sub_exc, formatted_traceback) for sub_exc in exc.exceptions) + sub_errors=tuple( + cls.from_exception(sub_exc, formatted_traceback) + for sub_exc in exc.exceptions + ), ) case _: return ErrorMetadata(