diff --git a/docs/api/chat.mdx b/docs/api/chat.mdx index aad1261..3d1469f 100644 --- a/docs/api/chat.mdx +++ b/docs/api/chat.mdx @@ -1910,11 +1910,6 @@ def map( ~~~ """ for callback in callbacks: - if not asyncio.iscoroutinefunction(callback): - raise TypeError( - f"Callback '{get_qualified_name(callback)}' must be an async function", - ) - if allow_duplicates: continue @@ -2661,11 +2656,6 @@ def then( ~~~ """ for callback in callbacks: - if not asyncio.iscoroutinefunction(callback): - raise TypeError( - f"Callback '{get_qualified_name(callback)}' must be an async function", - ) - if allow_duplicates: continue diff --git a/rigging/chat.py b/rigging/chat.py index 1d944f1..5e5c242 100644 --- a/rigging/chat.py +++ b/rigging/chat.py @@ -618,7 +618,7 @@ def __call__( self, chat: Chat, /, - ) -> t.Awaitable[Chat | None]: ... + ) -> t.Awaitable[Chat | None] | Chat | None: ... @runtime_checkable @@ -642,7 +642,7 @@ def __call__( self, chats: list[Chat], /, - ) -> t.Awaitable[list[Chat]]: ... + ) -> t.Awaitable[list[Chat]] | list[Chat]: ... @runtime_checkable @@ -773,7 +773,9 @@ async def traced_watch_callback(chats: list[Chat]) -> None: chat_count=len(chats), chat_ids=[str(c.uuid) for c in chats], ): - await callback(chats) + result = callback(chats) + if inspect.isawaitable(result): + await result return traced_watch_callback @@ -1100,11 +1102,6 @@ async def process(chat: Chat) -> Chat | None: ``` """ for callback in callbacks: - if not asyncio.iscoroutinefunction(callback): - raise TypeError( - f"Callback '{get_qualified_name(callback)}' must be an async function", - ) - if allow_duplicates: continue @@ -1147,11 +1144,6 @@ async def process(chats: list[Chat]) -> list[Chat]: ``` """ for callback in callbacks: - if not asyncio.iscoroutinefunction(callback): - raise TypeError( - f"Callback '{get_qualified_name(callback)}' must be an async function", - ) - if allow_duplicates: continue @@ -1565,9 +1557,8 @@ async def complete() -> None: exit_stack.push_async_callback(complete) result = callback(state.chat) - if inspect.isawaitable(result): - result = await result # type: ignore [assignment] + result = await result if result is None or isinstance(result, Chat): state.chat = result or state.chat diff --git a/tests/test_message_slicing.py b/tests/test_message_slicing.py index fb5718b..9dbff45 100644 --- a/tests/test_message_slicing.py +++ b/tests/test_message_slicing.py @@ -953,7 +953,13 @@ def test_slice_with_empty_string_target() -> None: """Test marking slice with empty string target.""" message = Message("assistant", "Some content here") - slice_obj = message.mark_slice("") + # Expect a "Empty string target provided" warning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + slice_obj = message.mark_slice("") + assert len(w) == 1 + assert issubclass(w[-1].category, MessageWarning) + assert "Empty string target provided" in str(w[-1].message) # Empty string should not create a valid slice assert slice_obj is None