diff --git a/AGENTS.md b/AGENTS.md index 537cabfcf..d15c51914 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -123,7 +123,7 @@ slack_bolt/adapter/flask/handler.py # sync-only (no async Flask a ### AI Agents & Assistants -`BoltAgent` (`slack_bolt/agent/`) provides `chat_stream()`, `set_status()`, and `set_suggested_prompts()` for AI-powered agents. `Assistant` middleware (`slack_bolt/middleware/assistant/`) handles assistant thread events. +`SayStream` (`slack_bolt/context/say_stream/`) provides `chat_stream()` for AI-powered agents (experimental). `Assistant` middleware (`slack_bolt/middleware/assistant/`) handles assistant thread events. `set_status` and `set_suggested_prompts` are available as context utilities for assistant events. ## Key Development Patterns diff --git a/docs/english/experiments.md b/docs/english/experiments.md index 681c8cbc6..c97f15496 100644 --- a/docs/english/experiments.md +++ b/docs/english/experiments.md @@ -1,34 +1,32 @@ # Experiments -Bolt for Python includes experimental features still under active development. These features may be fleeting, may not be perfectly polished, and should be thought of as available for use "at your own risk." +Bolt for Python includes experimental features still under active development. These features may be fleeting, may not be perfectly polished, and should be thought of as available for use "at your own risk." Experimental features are categorized as `semver:patch` until the experimental status is removed. We love feedback from our community, so we encourage you to explore and interact with the [GitHub repo](https://github.com/slackapi/bolt-python). Contributions, bug reports, and any feedback are all helpful; let us nurture the Slack CLI together to help make building Slack apps more pleasant for everyone. ## Available experiments -* [Agent listener argument](#agent) +* [say_stream listener argument](#say-stream) -## Agent listener argument {#agent} +## say_stream listener argument {#say-stream} -The `agent: BoltAgent` listener argument provides access to AI agent-related features. +The `say_stream: SayStream` listener argument provides access to streaming chat for AI agents. -The `BoltAgent` and `AsyncBoltAgent` classes offer a `chat_stream()` method that comes pre-configured with event context defaults: `channel_id`, `thread_ts`, `team_id`, and `user_id` fields. - -The listener argument is wired into the Bolt `kwargs` injection system, so listeners can declare it as a parameter or access it via the `context.agent` property. +`SayStream` and `AsyncSayStream` are callable context utilities pre-configured with event context defaults: `channel_id`, `thread_ts`, `team_id`, and `user_id`. ### Example ```python -from slack_bolt import BoltAgent +from slack_bolt import SayStream @app.event("app_mention") -def handle_mention(agent: BoltAgent): - stream = agent.chat_stream() +def handle_mention(say_stream: SayStream): + stream = say_stream() stream.append(markdown_text="Hello!") stream.stop() ``` ### Limitations -The `chat_stream()` method currently only works when the `thread_ts` field is available in the event context (DMs and threaded replies). Top-level channel messages do not have a `thread_ts` field, and the `ts` field is not yet provided to `BoltAgent`. \ No newline at end of file +`say_stream()` requires either `thread_ts` or `event.ts` in the event context. It works in DMs, threaded replies, and top-level messages with a `ts` field. diff --git a/slack_bolt/__init__.py b/slack_bolt/__init__.py index 4e43252fd..e834baa40 100644 --- a/slack_bolt/__init__.py +++ b/slack_bolt/__init__.py @@ -21,7 +21,7 @@ from .response import BoltResponse # AI Agents & Assistants -from .agent import BoltAgent +from .context.say_stream import SayStream from .middleware.assistant.assistant import ( Assistant, ) @@ -47,7 +47,7 @@ "CustomListenerMatcher", "BoltRequest", "BoltResponse", - "BoltAgent", + "SayStream", "Assistant", "AssistantThreadContext", "AssistantThreadContextStore", diff --git a/slack_bolt/agent/__init__.py b/slack_bolt/agent/__init__.py deleted file mode 100644 index 4d751f27f..000000000 --- a/slack_bolt/agent/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .agent import BoltAgent - -__all__ = [ - "BoltAgent", -] diff --git a/slack_bolt/agent/agent.py b/slack_bolt/agent/agent.py deleted file mode 100644 index 523b0e33c..000000000 --- a/slack_bolt/agent/agent.py +++ /dev/null @@ -1,139 +0,0 @@ -from typing import Dict, List, Optional, Sequence, Union - -from slack_sdk import WebClient -from slack_sdk.web import SlackResponse -from slack_sdk.web.chat_stream import ChatStream - - -class BoltAgent: - """Agent listener argument for building AI-powered Slack agents. - - Experimental: - This API is experimental and may change in future releases. - - @app.event("app_mention") - def handle_mention(agent): - stream = agent.chat_stream() - stream.append(markdown_text="Hello!") - stream.stop() - """ - - def __init__( - self, - *, - client: WebClient, - channel_id: Optional[str] = None, - thread_ts: Optional[str] = None, - ts: Optional[str] = None, - team_id: Optional[str] = None, - user_id: Optional[str] = None, - ): - self._client = client - self._channel_id = channel_id - self._thread_ts = thread_ts - self._ts = ts - self._team_id = team_id - self._user_id = user_id - - def chat_stream( - self, - *, - channel: Optional[str] = None, - thread_ts: Optional[str] = None, - recipient_team_id: Optional[str] = None, - recipient_user_id: Optional[str] = None, - **kwargs, - ) -> ChatStream: - """Creates a ChatStream with defaults from event context. - - Each call creates a new instance. Create multiple for parallel streams. - - Args: - channel: Channel ID. Defaults to the channel from the event context. - thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. - recipient_team_id: Team ID of the recipient. Defaults to the team from the event context. - recipient_user_id: User ID of the recipient. Defaults to the user from the event context. - **kwargs: Additional arguments passed to ``WebClient.chat_stream()``. - - Returns: - A new ``ChatStream`` instance. - """ - provided = [arg for arg in (channel, thread_ts, recipient_team_id, recipient_user_id) if arg is not None] - if provided and len(provided) < 4: - raise ValueError( - "Either provide all of channel, thread_ts, recipient_team_id, and recipient_user_id, or none of them" - ) - # Argument validation is delegated to chat_stream() and the API - return self._client.chat_stream( - channel=channel or self._channel_id, # type: ignore[arg-type] - thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type] - recipient_team_id=recipient_team_id or self._team_id, - recipient_user_id=recipient_user_id or self._user_id, - **kwargs, - ) - - def set_status( - self, - *, - status: str, - loading_messages: Optional[List[str]] = None, - channel_id: Optional[str] = None, - thread_ts: Optional[str] = None, - **kwargs, - ) -> SlackResponse: - """Sets the status of an assistant thread. - - Args: - status: The status text to display. - loading_messages: Optional list of loading messages to cycle through. - channel_id: Channel ID. Defaults to the channel from the event context. - thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. - **kwargs: Additional arguments passed to ``WebClient.assistant_threads_setStatus()``. - - Returns: - ``SlackResponse`` from the API call. - """ - return self._client.assistant_threads_setStatus( - channel_id=channel_id or self._channel_id, # type: ignore[arg-type] - thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type] - status=status, - loading_messages=loading_messages, - **kwargs, - ) - - def set_suggested_prompts( - self, - *, - prompts: Sequence[Union[str, Dict[str, str]]], - title: Optional[str] = None, - channel_id: Optional[str] = None, - thread_ts: Optional[str] = None, - **kwargs, - ) -> SlackResponse: - """Sets suggested prompts for an assistant thread. - - Args: - prompts: A sequence of prompts. Each prompt can be either a string - (used as both title and message) or a dict with 'title' and 'message' keys. - title: Optional title for the suggested prompts section. - channel_id: Channel ID. Defaults to the channel from the event context. - thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. - **kwargs: Additional arguments passed to ``WebClient.assistant_threads_setSuggestedPrompts()``. - - Returns: - ``SlackResponse`` from the API call. - """ - prompts_arg: List[Dict[str, str]] = [] - for prompt in prompts: - if isinstance(prompt, str): - prompts_arg.append({"title": prompt, "message": prompt}) - else: - prompts_arg.append(prompt) - - return self._client.assistant_threads_setSuggestedPrompts( - channel_id=channel_id or self._channel_id, # type: ignore[arg-type] - thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type] - prompts=prompts_arg, - title=title, - **kwargs, - ) diff --git a/slack_bolt/agent/async_agent.py b/slack_bolt/agent/async_agent.py deleted file mode 100644 index da4ec6c0a..000000000 --- a/slack_bolt/agent/async_agent.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Dict, List, Optional, Sequence, Union - -from slack_sdk.web.async_client import AsyncSlackResponse, AsyncWebClient -from slack_sdk.web.async_chat_stream import AsyncChatStream - - -class AsyncBoltAgent: - """Async agent listener argument for building AI-powered Slack agents. - - Experimental: - This API is experimental and may change in future releases. - - @app.event("app_mention") - async def handle_mention(agent): - stream = await agent.chat_stream() - await stream.append(markdown_text="Hello!") - await stream.stop() - """ - - def __init__( - self, - *, - client: AsyncWebClient, - channel_id: Optional[str] = None, - thread_ts: Optional[str] = None, - ts: Optional[str] = None, - team_id: Optional[str] = None, - user_id: Optional[str] = None, - ): - self._client = client - self._channel_id = channel_id - self._thread_ts = thread_ts - self._ts = ts - self._team_id = team_id - self._user_id = user_id - - async def chat_stream( - self, - *, - channel: Optional[str] = None, - thread_ts: Optional[str] = None, - recipient_team_id: Optional[str] = None, - recipient_user_id: Optional[str] = None, - **kwargs, - ) -> AsyncChatStream: - """Creates an AsyncChatStream with defaults from event context. - - Each call creates a new instance. Create multiple for parallel streams. - - Args: - channel: Channel ID. Defaults to the channel from the event context. - thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. - recipient_team_id: Team ID of the recipient. Defaults to the team from the event context. - recipient_user_id: User ID of the recipient. Defaults to the user from the event context. - **kwargs: Additional arguments passed to ``AsyncWebClient.chat_stream()``. - - Returns: - A new ``AsyncChatStream`` instance. - """ - provided = [arg for arg in (channel, thread_ts, recipient_team_id, recipient_user_id) if arg is not None] - if provided and len(provided) < 4: - raise ValueError( - "Either provide all of channel, thread_ts, recipient_team_id, and recipient_user_id, or none of them" - ) - # Argument validation is delegated to chat_stream() and the API - return await self._client.chat_stream( - channel=channel or self._channel_id, # type: ignore[arg-type] - thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type] - recipient_team_id=recipient_team_id or self._team_id, - recipient_user_id=recipient_user_id or self._user_id, - **kwargs, - ) - - async def set_status( - self, - *, - status: str, - loading_messages: Optional[List[str]] = None, - channel_id: Optional[str] = None, - thread_ts: Optional[str] = None, - **kwargs, - ) -> AsyncSlackResponse: - """Sets the status of an assistant thread. - - Args: - status: The status text to display. - loading_messages: Optional list of loading messages to cycle through. - channel_id: Channel ID. Defaults to the channel from the event context. - thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. - **kwargs: Additional arguments passed to ``AsyncWebClient.assistant_threads_setStatus()``. - - Returns: - ``AsyncSlackResponse`` from the API call. - """ - return await self._client.assistant_threads_setStatus( - channel_id=channel_id or self._channel_id, # type: ignore[arg-type] - thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type] - status=status, - loading_messages=loading_messages, - **kwargs, - ) - - async def set_suggested_prompts( - self, - *, - prompts: Sequence[Union[str, Dict[str, str]]], - title: Optional[str] = None, - channel_id: Optional[str] = None, - thread_ts: Optional[str] = None, - **kwargs, - ) -> AsyncSlackResponse: - """Sets suggested prompts for an assistant thread. - - Args: - prompts: A sequence of prompts. Each prompt can be either a string - (used as both title and message) or a dict with 'title' and 'message' keys. - title: Optional title for the suggested prompts section. - channel_id: Channel ID. Defaults to the channel from the event context. - thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. - **kwargs: Additional arguments passed to ``AsyncWebClient.assistant_threads_setSuggestedPrompts()``. - - Returns: - ``AsyncSlackResponse`` from the API call. - """ - prompts_arg: List[Dict[str, str]] = [] - for prompt in prompts: - if isinstance(prompt, str): - prompts_arg.append({"title": prompt, "message": prompt}) - else: - prompts_arg.append(prompt) - - return await self._client.assistant_threads_setSuggestedPrompts( - channel_id=channel_id or self._channel_id, # type: ignore[arg-type] - thread_ts=thread_ts or self._thread_ts or self._ts, # type: ignore[arg-type] - prompts=prompts_arg, - title=title, - **kwargs, - ) diff --git a/slack_bolt/app/app.py b/slack_bolt/app/app.py index 5a7f32917..6503c62e7 100644 --- a/slack_bolt/app/app.py +++ b/slack_bolt/app/app.py @@ -70,6 +70,7 @@ IgnoringSelfEvents, CustomMiddleware, AttachingFunctionToken, + AttachingAgentKwargs, ) from slack_bolt.middleware.assistant import Assistant from slack_bolt.middleware.message_listener_matches import MessageListenerMatches @@ -128,6 +129,7 @@ def __init__( ssl_check_enabled: bool = True, url_verification_enabled: bool = True, attaching_function_token_enabled: bool = True, + attaching_agent_kwargs_enabled: bool = True, # for the OAuth flow oauth_settings: Optional[OAuthSettings] = None, oauth_flow: Optional[OAuthFlow] = None, @@ -357,6 +359,7 @@ def message_hello(message, say): listener_executor = ThreadPoolExecutor(max_workers=5) self._assistant_thread_context_store = assistant_thread_context_store + self._attaching_agent_kwargs_enabled = attaching_agent_kwargs_enabled self._process_before_response = process_before_response self._listener_runner = ThreadListenerRunner( @@ -841,10 +844,14 @@ def ask_for_introduction(event, say): middleware: A list of lister middleware functions. Only when all the middleware call `next()` method, the listener function can be invoked. """ + matchers = list(matchers) if matchers else [] + middleware = list(middleware) if middleware else [] def __call__(*args, **kwargs): functions = self._to_listener_functions(kwargs) if kwargs else list(args) primary_matcher = builtin_matchers.event(event, base_logger=self._base_logger) + if self._attaching_agent_kwargs_enabled: + middleware.insert(0, AttachingAgentKwargs()) return self._register_listener(list(functions), primary_matcher, matchers, middleware, True) return __call__ @@ -902,6 +909,8 @@ def __call__(*args, **kwargs): primary_matcher = builtin_matchers.message_event( keyword=keyword, constraints=constraints, base_logger=self._base_logger ) + if self._attaching_agent_kwargs_enabled: + middleware.insert(0, AttachingAgentKwargs()) middleware.insert(0, MessageListenerMatches(keyword)) return self._register_listener(list(functions), primary_matcher, matchers, middleware, True) diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index 39f3c3c0e..64bf576e4 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -88,6 +88,7 @@ AsyncIgnoringSelfEvents, AsyncUrlVerification, AsyncAttachingFunctionToken, + AsyncAttachingAgentKwargs, ) from slack_bolt.middleware.async_custom_middleware import ( AsyncMiddleware, @@ -136,6 +137,7 @@ def __init__( ssl_check_enabled: bool = True, url_verification_enabled: bool = True, attaching_function_token_enabled: bool = True, + attaching_agent_kwargs_enabled: bool = True, # for the OAuth flow oauth_settings: Optional[AsyncOAuthSettings] = None, oauth_flow: Optional[AsyncOAuthFlow] = None, @@ -363,6 +365,7 @@ async def message_hello(message, say): # async function self._async_listeners: List[AsyncListener] = [] self._assistant_thread_context_store = assistant_thread_context_store + self._attaching_agent_kwargs_enabled = attaching_agent_kwargs_enabled self._process_before_response = process_before_response self._async_listener_runner = AsyncioListenerRunner( @@ -866,10 +869,14 @@ async def ask_for_introduction(event, say): middleware: A list of lister middleware functions. Only when all the middleware call `next()` method, the listener function can be invoked. """ + matchers = list(matchers) if matchers else [] + middleware = list(middleware) if middleware else [] def __call__(*args, **kwargs): functions = self._to_listener_functions(kwargs) if kwargs else list(args) primary_matcher = builtin_matchers.event(event, True, base_logger=self._base_logger) + if self._attaching_agent_kwargs_enabled: + middleware.insert(0, AsyncAttachingAgentKwargs()) return self._register_listener(list(functions), primary_matcher, matchers, middleware, True) return __call__ @@ -930,6 +937,8 @@ def __call__(*args, **kwargs): asyncio=True, base_logger=self._base_logger, ) + if self._attaching_agent_kwargs_enabled: + middleware.insert(0, AsyncAttachingAgentKwargs()) middleware.insert(0, AsyncMessageListenerMatches(keyword)) return self._register_listener(list(functions), primary_matcher, matchers, middleware, True) diff --git a/slack_bolt/context/async_context.py b/slack_bolt/context/async_context.py index 47eb4744e..54d7b29e7 100644 --- a/slack_bolt/context/async_context.py +++ b/slack_bolt/context/async_context.py @@ -10,6 +10,7 @@ from slack_bolt.context.get_thread_context.async_get_thread_context import AsyncGetThreadContext from slack_bolt.context.save_thread_context.async_save_thread_context import AsyncSaveThreadContext from slack_bolt.context.say.async_say import AsyncSay +from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream from slack_bolt.context.set_status.async_set_status import AsyncSetStatus from slack_bolt.context.set_suggested_prompts.async_set_suggested_prompts import AsyncSetSuggestedPrompts from slack_bolt.context.set_title.async_set_title import AsyncSetTitle @@ -187,6 +188,10 @@ async def handle_button_clicks(context): self["fail"] = AsyncFail(client=self.client, function_execution_id=self.function_execution_id) return self["fail"] + @property + def say_stream(self) -> Optional[AsyncSayStream]: + return self.get("say_stream") + @property def set_title(self) -> Optional[AsyncSetTitle]: return self.get("set_title") diff --git a/slack_bolt/context/base_context.py b/slack_bolt/context/base_context.py index 843d5ef60..502febcb8 100644 --- a/slack_bolt/context/base_context.py +++ b/slack_bolt/context/base_context.py @@ -38,6 +38,7 @@ class BaseContext(dict): "set_status", "set_title", "set_suggested_prompts", + "say_stream", ] # Note that these items are not copyable, so when you add new items to this list, # you must modify ThreadListenerRunner/AsyncioListenerRunner's _build_lazy_request method to pass the values. diff --git a/slack_bolt/context/context.py b/slack_bolt/context/context.py index 31edf2891..d3267abdc 100644 --- a/slack_bolt/context/context.py +++ b/slack_bolt/context/context.py @@ -10,6 +10,7 @@ from slack_bolt.context.respond import Respond from slack_bolt.context.save_thread_context import SaveThreadContext from slack_bolt.context.say import Say +from slack_bolt.context.say_stream import SayStream from slack_bolt.context.set_status import SetStatus from slack_bolt.context.set_suggested_prompts import SetSuggestedPrompts from slack_bolt.context.set_title import SetTitle @@ -188,6 +189,10 @@ def handle_button_clicks(context): self["fail"] = Fail(client=self.client, function_execution_id=self.function_execution_id) return self["fail"] + @property + def say_stream(self) -> Optional[SayStream]: + return self.get("say_stream") + @property def set_title(self) -> Optional[SetTitle]: return self.get("set_title") diff --git a/slack_bolt/context/say_stream/__init__.py b/slack_bolt/context/say_stream/__init__.py new file mode 100644 index 000000000..86db7b1cc --- /dev/null +++ b/slack_bolt/context/say_stream/__init__.py @@ -0,0 +1,6 @@ +# Don't add async module imports here +from .say_stream import SayStream + +__all__ = [ + "SayStream", +] diff --git a/slack_bolt/context/say_stream/async_say_stream.py b/slack_bolt/context/say_stream/async_say_stream.py new file mode 100644 index 000000000..7c115179b --- /dev/null +++ b/slack_bolt/context/say_stream/async_say_stream.py @@ -0,0 +1,55 @@ +from typing import Optional + +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.web.async_chat_stream import AsyncChatStream + + +class AsyncSayStream: + client: AsyncWebClient + channel_id: Optional[str] + thread_ts: Optional[str] + team_id: Optional[str] + user_id: Optional[str] + + def __init__( + self, + client: AsyncWebClient, + channel_id: Optional[str], + thread_ts: Optional[str], + team_id: Optional[str], + user_id: Optional[str], + ): + self.client = client + self.channel_id = channel_id + self.thread_ts = thread_ts + self.team_id = team_id + self.user_id = user_id + + async def __call__( + self, + *, + channel: Optional[str] = None, + thread_ts: Optional[str] = None, + recipient_team_id: Optional[str] = None, + recipient_user_id: Optional[str] = None, + **kwargs, + ) -> AsyncChatStream: + provided = [arg for arg in (channel, thread_ts, recipient_team_id, recipient_user_id) if arg is not None] + if provided and len(provided) < 4: + raise ValueError( + "Either provide all of channel, thread_ts, recipient_team_id, and recipient_user_id, or none of them" + ) + channel = channel or self.channel_id + thread_ts = thread_ts or self.thread_ts + if channel is None: + raise ValueError("say_stream is unsupported here as there is no channel_id") + if thread_ts is None: + raise ValueError("say_stream is unsupported here as there is no thread_ts") + + return await self.client.chat_stream( + channel=channel, + thread_ts=thread_ts, + recipient_team_id=recipient_team_id or self.team_id, + recipient_user_id=recipient_user_id or self.user_id, + **kwargs, + ) diff --git a/slack_bolt/context/say_stream/say_stream.py b/slack_bolt/context/say_stream/say_stream.py new file mode 100644 index 000000000..20cbb03f5 --- /dev/null +++ b/slack_bolt/context/say_stream/say_stream.py @@ -0,0 +1,55 @@ +from typing import Optional + +from slack_sdk import WebClient +from slack_sdk.web.chat_stream import ChatStream + + +class SayStream: + client: WebClient + channel_id: Optional[str] + thread_ts: Optional[str] + team_id: Optional[str] + user_id: Optional[str] + + def __init__( + self, + client: WebClient, + channel_id: Optional[str], + thread_ts: Optional[str], + team_id: Optional[str], + user_id: Optional[str], + ): + self.client = client + self.channel_id = channel_id + self.thread_ts = thread_ts + self.team_id = team_id + self.user_id = user_id + + def __call__( + self, + *, + channel: Optional[str] = None, + thread_ts: Optional[str] = None, + recipient_team_id: Optional[str] = None, + recipient_user_id: Optional[str] = None, + **kwargs, + ) -> ChatStream: + provided = [arg for arg in (channel, thread_ts, recipient_team_id, recipient_user_id) if arg is not None] + if provided and len(provided) < 4: + raise ValueError( + "Either provide all of channel, thread_ts, recipient_team_id, and recipient_user_id, or none of them" + ) + channel = channel or self.channel_id + thread_ts = thread_ts or self.thread_ts + if channel is None: + raise ValueError("say_stream is unsupported here as there is no channel_id") + if thread_ts is None: + raise ValueError("say_stream is unsupported here as there is no thread_ts") + + return self.client.chat_stream( + channel=channel, + thread_ts=thread_ts, + recipient_team_id=recipient_team_id or self.team_id, + recipient_user_id=recipient_user_id or self.user_id, + **kwargs, + ) diff --git a/slack_bolt/kwargs_injection/args.py b/slack_bolt/kwargs_injection/args.py index 113e39c08..4cd70176d 100644 --- a/slack_bolt/kwargs_injection/args.py +++ b/slack_bolt/kwargs_injection/args.py @@ -8,9 +8,9 @@ from slack_bolt.context.fail import Fail from slack_bolt.context.get_thread_context.get_thread_context import GetThreadContext from slack_bolt.context.respond import Respond -from slack_bolt.agent.agent import BoltAgent from slack_bolt.context.save_thread_context import SaveThreadContext from slack_bolt.context.say import Say +from slack_bolt.context.say_stream import SayStream from slack_bolt.context.set_status import SetStatus from slack_bolt.context.set_suggested_prompts import SetSuggestedPrompts from slack_bolt.context.set_title import SetTitle @@ -103,8 +103,8 @@ def handle_buttons(args): """`get_thread_context()` utility function for AI Agents & Assistants""" save_thread_context: Optional[SaveThreadContext] """`save_thread_context()` utility function for AI Agents & Assistants""" - agent: Optional[BoltAgent] - """`agent` listener argument for AI Agents & Assistants""" + say_stream: Optional[SayStream] + """`say_stream()` utility function for AI Agents & Assistants""" # middleware next: Callable[[], None] """`next()` utility function, which tells the middleware chain that it can continue with the next one""" @@ -138,7 +138,7 @@ def __init__( set_suggested_prompts: Optional[SetSuggestedPrompts] = None, get_thread_context: Optional[GetThreadContext] = None, save_thread_context: Optional[SaveThreadContext] = None, - agent: Optional[BoltAgent] = None, + say_stream: Optional[SayStream] = None, # As this method is not supposed to be invoked by bolt-python users, # the naming conflict with the built-in one affects # only the internals of this method @@ -172,7 +172,7 @@ def __init__( self.set_suggested_prompts = set_suggested_prompts self.get_thread_context = get_thread_context self.save_thread_context = save_thread_context - self.agent = agent + self.say_stream = say_stream self.next: Callable[[], None] = next self.next_: Callable[[], None] = next diff --git a/slack_bolt/kwargs_injection/async_args.py b/slack_bolt/kwargs_injection/async_args.py index 1f1dde024..2217cfe9f 100644 --- a/slack_bolt/kwargs_injection/async_args.py +++ b/slack_bolt/kwargs_injection/async_args.py @@ -1,7 +1,6 @@ from logging import Logger from typing import Callable, Awaitable, Dict, Any, Optional -from slack_bolt.agent.async_agent import AsyncBoltAgent from slack_bolt.context.ack.async_ack import AsyncAck from slack_bolt.context.async_context import AsyncBoltContext from slack_bolt.context.complete.async_complete import AsyncComplete @@ -10,6 +9,7 @@ from slack_bolt.context.get_thread_context.async_get_thread_context import AsyncGetThreadContext from slack_bolt.context.save_thread_context.async_save_thread_context import AsyncSaveThreadContext from slack_bolt.context.say.async_say import AsyncSay +from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream from slack_bolt.context.set_status.async_set_status import AsyncSetStatus from slack_bolt.context.set_suggested_prompts.async_set_suggested_prompts import AsyncSetSuggestedPrompts from slack_bolt.context.set_title.async_set_title import AsyncSetTitle @@ -102,8 +102,8 @@ async def handle_buttons(args): """`get_thread_context()` utility function for AI Agents & Assistants""" save_thread_context: Optional[AsyncSaveThreadContext] """`save_thread_context()` utility function for AI Agents & Assistants""" - agent: Optional[AsyncBoltAgent] - """`agent` listener argument for AI Agents & Assistants""" + say_stream: Optional[AsyncSayStream] + """`say_stream()` utility function for AI Agents & Assistants""" # middleware next: Callable[[], Awaitable[None]] """`next()` utility function, which tells the middleware chain that it can continue with the next one""" @@ -137,7 +137,7 @@ def __init__( set_suggested_prompts: Optional[AsyncSetSuggestedPrompts] = None, get_thread_context: Optional[AsyncGetThreadContext] = None, save_thread_context: Optional[AsyncSaveThreadContext] = None, - agent: Optional[AsyncBoltAgent] = None, + say_stream: Optional[AsyncSayStream] = None, next: Callable[[], Awaitable[None]], **kwargs, # noqa ): @@ -168,7 +168,7 @@ def __init__( self.set_suggested_prompts = set_suggested_prompts self.get_thread_context = get_thread_context self.save_thread_context = save_thread_context - self.agent = agent + self.say_stream = say_stream self.next: Callable[[], Awaitable[None]] = next self.next_: Callable[[], Awaitable[None]] = next diff --git a/slack_bolt/kwargs_injection/async_utils.py b/slack_bolt/kwargs_injection/async_utils.py index aa84b2d11..c9c255b19 100644 --- a/slack_bolt/kwargs_injection/async_utils.py +++ b/slack_bolt/kwargs_injection/async_utils.py @@ -59,6 +59,7 @@ def build_async_required_kwargs( "set_title": request.context.set_title, "set_suggested_prompts": request.context.set_suggested_prompts, "get_thread_context": request.context.get_thread_context, + "say_stream": request.context.say_stream, "save_thread_context": request.context.save_thread_context, # middleware "next": next_func, @@ -85,26 +86,6 @@ def build_async_required_kwargs( if k not in all_available_args: all_available_args[k] = v - # Defer agent creation to avoid constructing AsyncBoltAgent on every request - if "agent" in required_arg_names: - from slack_bolt.agent.async_agent import AsyncBoltAgent - - event = request.body.get("event", {}) - - all_available_args["agent"] = AsyncBoltAgent( - client=request.context.client, - channel_id=request.context.channel_id, - thread_ts=request.context.thread_ts or event.get("thread_ts"), - ts=event.get("ts"), - team_id=request.context.team_id, - user_id=request.context.user_id, - ) - warnings.warn( - "The agent listener argument is experimental and may change in future versions.", - category=ExperimentalWarning, - stacklevel=2, # Point to the caller, not this internal helper - ) - if len(required_arg_names) > 0: # To support instance/class methods in a class for listeners/middleware, # check if the first argument is either self or cls @@ -119,6 +100,13 @@ def build_async_required_kwargs( # We are sure that we should skip manipulating this arg required_arg_names.pop(0) + if "say_stream" in required_arg_names: + warnings.warn( + "The say_stream listener argument is experimental and may change in future versions.", + category=ExperimentalWarning, + stacklevel=2, # Point to the caller, not this internal helper + ) + kwargs: Dict[str, Any] = {k: v for k, v in all_available_args.items() if k in required_arg_names} found_arg_names = kwargs.keys() for name in required_arg_names: diff --git a/slack_bolt/kwargs_injection/utils.py b/slack_bolt/kwargs_injection/utils.py index 5cd410a07..45185b124 100644 --- a/slack_bolt/kwargs_injection/utils.py +++ b/slack_bolt/kwargs_injection/utils.py @@ -58,6 +58,7 @@ def build_required_kwargs( "set_status": request.context.set_status, "set_title": request.context.set_title, "set_suggested_prompts": request.context.set_suggested_prompts, + "say_stream": request.context.say_stream, "save_thread_context": request.context.save_thread_context, # middleware "next": next_func, @@ -84,26 +85,6 @@ def build_required_kwargs( if k not in all_available_args: all_available_args[k] = v - # Defer agent creation to avoid constructing BoltAgent on every request - if "agent" in required_arg_names: - from slack_bolt.agent.agent import BoltAgent - - event = request.body.get("event", {}) - - all_available_args["agent"] = BoltAgent( - client=request.context.client, - channel_id=request.context.channel_id, - thread_ts=request.context.thread_ts or event.get("thread_ts"), - ts=event.get("ts"), - team_id=request.context.team_id, - user_id=request.context.user_id, - ) - warnings.warn( - "The agent listener argument is experimental and may change in future versions.", - category=ExperimentalWarning, - stacklevel=2, # Point to the caller, not this internal helper - ) - if len(required_arg_names) > 0: # To support instance/class methods in a class for listeners/middleware, # check if the first argument is either self or cls @@ -118,6 +99,13 @@ def build_required_kwargs( # We are sure that we should skip manipulating this arg required_arg_names.pop(0) + if "say_stream" in required_arg_names: + warnings.warn( + "The say_stream listener argument is experimental and may change in future versions.", + category=ExperimentalWarning, + stacklevel=2, # Point to the caller, not this internal helper + ) + kwargs: Dict[str, Any] = {k: v for k, v in all_available_args.items() if k in required_arg_names} found_arg_names = kwargs.keys() for name in required_arg_names: diff --git a/slack_bolt/middleware/__init__.py b/slack_bolt/middleware/__init__.py index 0e4044f99..a68e4cdd8 100644 --- a/slack_bolt/middleware/__init__.py +++ b/slack_bolt/middleware/__init__.py @@ -17,6 +17,7 @@ from .ssl_check import SslCheck from .url_verification import UrlVerification from .attaching_function_token import AttachingFunctionToken +from .attaching_agent_kwargs import AttachingAgentKwargs builtin_middleware_classes = [ SslCheck, @@ -26,6 +27,7 @@ IgnoringSelfEvents, UrlVerification, AttachingFunctionToken, + AttachingAgentKwargs, # Assistant, # to avoid circular imports ] for cls in builtin_middleware_classes: @@ -41,5 +43,6 @@ "SslCheck", "UrlVerification", "AttachingFunctionToken", + "AttachingAgentKwargs", "builtin_middleware_classes", ] diff --git a/slack_bolt/middleware/async_builtins.py b/slack_bolt/middleware/async_builtins.py index d2d82c1fb..755b55c20 100644 --- a/slack_bolt/middleware/async_builtins.py +++ b/slack_bolt/middleware/async_builtins.py @@ -10,6 +10,7 @@ AsyncMessageListenerMatches, ) from .attaching_function_token.async_attaching_function_token import AsyncAttachingFunctionToken +from .attaching_agent_kwargs.async_attaching_agent_kwargs import AsyncAttachingAgentKwargs __all__ = [ "AsyncIgnoringSelfEvents", @@ -18,4 +19,5 @@ "AsyncUrlVerification", "AsyncMessageListenerMatches", "AsyncAttachingFunctionToken", + "AsyncAttachingAgentKwargs", ] diff --git a/slack_bolt/middleware/attaching_agent_kwargs/__init__.py b/slack_bolt/middleware/attaching_agent_kwargs/__init__.py new file mode 100644 index 000000000..98926fc14 --- /dev/null +++ b/slack_bolt/middleware/attaching_agent_kwargs/__init__.py @@ -0,0 +1,5 @@ +from .attaching_agent_kwargs import AttachingAgentKwargs + +__all__ = [ + "AttachingAgentKwargs", +] diff --git a/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py b/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py new file mode 100644 index 000000000..319e65b20 --- /dev/null +++ b/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py @@ -0,0 +1,40 @@ +from typing import Callable, Awaitable + +from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream +from slack_bolt.context.set_status.async_set_status import AsyncSetStatus +from slack_bolt.context.set_suggested_prompts.async_set_suggested_prompts import AsyncSetSuggestedPrompts +from slack_bolt.context.set_title.async_set_title import AsyncSetTitle +from slack_bolt.middleware.async_middleware import AsyncMiddleware +from slack_bolt.request.async_request import AsyncBoltRequest +from slack_bolt.response import BoltResponse + + +class AsyncAttachingAgentKwargs(AsyncMiddleware): + async def async_process( + self, + *, + req: AsyncBoltRequest, + resp: BoltResponse, + # This method is not supposed to be invoked by bolt-python users + next: Callable[[], Awaitable[BoltResponse]], + ) -> BoltResponse: + channel_id = req.context.channel_id + # TODO: improve the logic around extracting thread_ts and event ts + event = req.body.get("event", {}) + req.context.thread_ts + thread_ts = event.get("thread_ts") or event.get("ts") + + if channel_id and thread_ts: + client = req.context.client + req.context["set_status"] = AsyncSetStatus(client, channel_id, thread_ts) + req.context["set_title"] = AsyncSetTitle(client, channel_id, thread_ts) + req.context["set_suggested_prompts"] = AsyncSetSuggestedPrompts(client, channel_id, thread_ts) + req.context["say_stream"] = AsyncSayStream( + client=req.context.client, + channel_id=channel_id, + thread_ts=thread_ts, + team_id=req.context.team_id, + user_id=req.context.user_id, + ) + + return await next() diff --git a/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py b/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py new file mode 100644 index 000000000..66ad2c332 --- /dev/null +++ b/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py @@ -0,0 +1,39 @@ +from typing import Callable + +from slack_bolt.context.say_stream.say_stream import SayStream +from slack_bolt.context.set_status.set_status import SetStatus +from slack_bolt.context.set_suggested_prompts.set_suggested_prompts import SetSuggestedPrompts +from slack_bolt.context.set_title.set_title import SetTitle +from slack_bolt.middleware.middleware import Middleware +from slack_bolt.request import BoltRequest +from slack_bolt.response import BoltResponse + + +class AttachingAgentKwargs(Middleware): + def process( + self, + *, + req: BoltRequest, + resp: BoltResponse, + # This method is not supposed to be invoked by bolt-python users + next: Callable[[], BoltResponse], + ) -> BoltResponse: + channel_id = req.context.channel_id + # TODO: improve the logic around extracting thread_ts and event ts + event = req.body.get("event", {}) + thread_ts = event.get("thread_ts") or event.get("ts") + + if channel_id and thread_ts: + client = req.context.client + req.context["set_status"] = SetStatus(client, channel_id, thread_ts) + req.context["set_title"] = SetTitle(client, channel_id, thread_ts) + req.context["set_suggested_prompts"] = SetSuggestedPrompts(client, channel_id, thread_ts) + req.context["say_stream"] = SayStream( + client=req.context.client, + channel_id=channel_id, + thread_ts=thread_ts, + team_id=req.context.team_id, + user_id=req.context.user_id, + ) + + return next() diff --git a/slack_bolt/request/internals.py b/slack_bolt/request/internals.py index e6a32db0d..e70fe2765 100644 --- a/slack_bolt/request/internals.py +++ b/slack_bolt/request/internals.py @@ -214,13 +214,11 @@ def extract_channel_id(payload: Dict[str, Any]) -> Optional[str]: return None +# TODO: EXTEND THIS CLASS SO THAT WE CAN USE IT IN SAY_STREAM def extract_thread_ts(payload: Dict[str, Any]) -> Optional[str]: # This utility initially supports only the use cases for AI assistants, but it may be fine to add more patterns. # That said, note that thread_ts is always required for assistant threads, but it's not for channels. # Thus, blindly setting this thread_ts to say utility can break existing apps' behaviors. - # - # The BoltAgent class handles non-assistant thread_ts separately by reading from the event directly, - # allowing it to work correctly without affecting say() behavior. if is_assistant_event(payload): event = payload["event"] if ( @@ -242,6 +240,16 @@ def extract_thread_ts(payload: Dict[str, Any]) -> Optional[str]: elif event.get("previous_message", {}).get("thread_ts") is not None: # message_deleted return event["previous_message"]["thread_ts"] + return None + thread_ts = payload.get("thread_ts") + if thread_ts is not None: + return thread_ts + if payload.get("event") is not None: + return extract_thread_ts(payload["event"]) + if isinstance(payload.get("message"), dict): + return extract_thread_ts(payload["message"]) + if isinstance(payload.get("previous_message"), dict): + return extract_thread_ts(payload["previous_message"]) return None diff --git a/tests/scenario_tests/test_events_agent.py b/tests/scenario_tests/test_events_agent.py index 667739728..9e7fcb105 100644 --- a/tests/scenario_tests/test_events_agent.py +++ b/tests/scenario_tests/test_events_agent.py @@ -1,12 +1,10 @@ import json from time import sleep -import pytest from slack_sdk.web import WebClient -from slack_bolt import App, BoltRequest, BoltContext, BoltAgent -from slack_bolt.agent.agent import BoltAgent as BoltAgentDirect -from slack_bolt.warning import ExperimentalWarning +from slack_bolt import App, BoltRequest, BoltContext, SayStream +from slack_bolt.context.say_stream.say_stream import SayStream as SayStreamDirect from tests.mock_web_api_server import ( setup_mock_web_api_server, cleanup_mock_web_api_server, @@ -14,6 +12,8 @@ from tests.utils import remove_os_env_temporarily, restore_os_env +# TODO: VALIDATE THIS AI SLOP IS CORRECT +# TODO: REMANE THIS FILE AND CLASS NAME class TestEventsAgent: valid_token = "xoxb-valid" mock_api_server_base_url = "http://localhost:8888" @@ -30,7 +30,7 @@ def teardown_method(self): cleanup_mock_web_api_server(self) restore_os_env(self.old_os_env) - def test_agent_injected_for_app_mention(self): + def test_say_stream_injected_for_app_mention(self): app = App(client=self.web_client) state = {"called": False} @@ -44,9 +44,9 @@ def assert_target_called(): state["called"] = False @app.event("app_mention") - def handle_mention(agent: BoltAgent, context: BoltContext): - assert agent is not None - assert isinstance(agent, BoltAgentDirect) + def handle_mention(say_stream: SayStream, context: BoltContext): + assert say_stream is not None + assert isinstance(say_stream, SayStreamDirect) assert context.channel_id == "C111" state["called"] = True @@ -55,7 +55,7 @@ def handle_mention(agent: BoltAgent, context: BoltContext): assert response.status == 200 assert_target_called() - def test_agent_available_in_action_listener(self): + def test_say_stream_not_available_in_action_listener(self): app = App(client=self.web_client) state = {"called": False} @@ -69,10 +69,9 @@ def assert_target_called(): state["called"] = False @app.action("test_action") - def handle_action(ack, agent: BoltAgent): + def handle_action(ack, say_stream: SayStream): ack() - assert agent is not None - assert isinstance(agent, BoltAgentDirect) + assert say_stream is None state["called"] = True request = BoltRequest(body=json.dumps(action_event_body), mode="socket_mode") @@ -80,29 +79,6 @@ def handle_action(ack, agent: BoltAgent): assert response.status == 200 assert_target_called() - def test_agent_kwarg_emits_experimental_warning(self): - app = App(client=self.web_client) - - state = {"called": False} - - def assert_target_called(): - count = 0 - while state["called"] is False and count < 20: - sleep(0.1) - count += 1 - assert state["called"] is True - state["called"] = False - - @app.event("app_mention") - def handle_mention(agent: BoltAgent): - state["called"] = True - - request = BoltRequest(body=app_mention_event_body, mode="socket_mode") - with pytest.warns(ExperimentalWarning, match="agent listener argument is experimental"): - response = app.dispatch(request) - assert response.status == 200 - assert_target_called() - # ---- Test event bodies ---- diff --git a/tests/scenario_tests_async/test_events_agent.py b/tests/scenario_tests_async/test_events_agent.py index 1702cdb61..952666aaf 100644 --- a/tests/scenario_tests_async/test_events_agent.py +++ b/tests/scenario_tests_async/test_events_agent.py @@ -4,11 +4,10 @@ import pytest from slack_sdk.web.async_client import AsyncWebClient -from slack_bolt.agent.async_agent import AsyncBoltAgent from slack_bolt.app.async_app import AsyncApp from slack_bolt.context.async_context import AsyncBoltContext +from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream from slack_bolt.request.async_request import AsyncBoltRequest -from slack_bolt.warning import ExperimentalWarning from tests.mock_web_api_server import ( cleanup_mock_web_api_server_async, setup_mock_web_api_server_async, @@ -16,6 +15,8 @@ from tests.utils import remove_os_env_temporarily, restore_os_env +# TODO: VALIDATE THIS AI SLOP IS CORRECT +# TODO: REMANE THIS FILE AND CLASS NAME class TestAsyncEventsAgent: valid_token = "xoxb-valid" mock_api_server_base_url = "http://localhost:8888" @@ -35,7 +36,7 @@ def setup_teardown(self): restore_os_env(old_os_env) @pytest.mark.asyncio - async def test_agent_injected_for_app_mention(self): + async def test_say_stream_injected_for_app_mention(self): app = AsyncApp(client=self.web_client) state = {"called": False} @@ -49,9 +50,9 @@ async def assert_target_called(): state["called"] = False @app.event("app_mention") - async def handle_mention(agent: AsyncBoltAgent, context: AsyncBoltContext): - assert agent is not None - assert isinstance(agent, AsyncBoltAgent) + async def handle_mention(say_stream: AsyncSayStream, context: AsyncBoltContext): + assert say_stream is not None + assert isinstance(say_stream, AsyncSayStream) assert context.channel_id == "C111" state["called"] = True @@ -61,7 +62,7 @@ async def handle_mention(agent: AsyncBoltAgent, context: AsyncBoltContext): await assert_target_called() @pytest.mark.asyncio - async def test_agent_available_in_action_listener(self): + async def test_say_stream_not_available_in_action_listener(self): app = AsyncApp(client=self.web_client) state = {"called": False} @@ -75,10 +76,9 @@ async def assert_target_called(): state["called"] = False @app.action("test_action") - async def handle_action(ack, agent: AsyncBoltAgent): + async def handle_action(ack, say_stream: AsyncSayStream): await ack() - assert agent is not None - assert isinstance(agent, AsyncBoltAgent) + assert say_stream is None state["called"] = True request = AsyncBoltRequest(body=json.dumps(action_event_body), mode="socket_mode") @@ -86,30 +86,6 @@ async def handle_action(ack, agent: AsyncBoltAgent): assert response.status == 200 await assert_target_called() - @pytest.mark.asyncio - async def test_agent_kwarg_emits_experimental_warning(self): - app = AsyncApp(client=self.web_client) - - state = {"called": False} - - async def assert_target_called(): - count = 0 - while state["called"] is False and count < 20: - await asyncio.sleep(0.1) - count += 1 - assert state["called"] is True - state["called"] = False - - @app.event("app_mention") - async def handle_mention(agent: AsyncBoltAgent): - state["called"] = True - - request = AsyncBoltRequest(body=app_mention_event_body, mode="socket_mode") - with pytest.warns(ExperimentalWarning, match="agent listener argument is experimental"): - response = await app.async_dispatch(request) - assert response.status == 200 - await assert_target_called() - # ---- Test event bodies ---- diff --git a/tests/slack_bolt/agent/test_agent.py b/tests/slack_bolt/agent/test_agent.py deleted file mode 100644 index 76ac7d17b..000000000 --- a/tests/slack_bolt/agent/test_agent.py +++ /dev/null @@ -1,365 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from slack_sdk.web import WebClient -from slack_sdk.web.chat_stream import ChatStream - -from slack_bolt.agent.agent import BoltAgent - - -class TestBoltAgent: - def test_chat_stream_uses_context_defaults(self): - """BoltAgent.chat_stream() passes context defaults to WebClient.chat_stream().""" - client = MagicMock(spec=WebClient) - client.chat_stream.return_value = MagicMock(spec=ChatStream) - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - stream = agent.chat_stream() - - client.chat_stream.assert_called_once_with( - channel="C111", - thread_ts="1234567890.123456", - recipient_team_id="T111", - recipient_user_id="W222", - ) - assert stream is not None - - def test_chat_stream_overrides_context_defaults(self): - """Explicit kwargs to chat_stream() override context defaults.""" - client = MagicMock(spec=WebClient) - client.chat_stream.return_value = MagicMock(spec=ChatStream) - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - stream = agent.chat_stream( - channel="C999", - thread_ts="9999999999.999999", - recipient_team_id="T999", - recipient_user_id="U999", - ) - - client.chat_stream.assert_called_once_with( - channel="C999", - thread_ts="9999999999.999999", - recipient_team_id="T999", - recipient_user_id="U999", - ) - assert stream is not None - - def test_chat_stream_rejects_partial_overrides(self): - """Passing only some of the four context args raises ValueError.""" - client = MagicMock(spec=WebClient) - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - with pytest.raises(ValueError, match="Either provide all of"): - agent.chat_stream(channel="C999") - - def test_chat_stream_passes_extra_kwargs(self): - """Extra kwargs are forwarded to WebClient.chat_stream().""" - client = MagicMock(spec=WebClient) - client.chat_stream.return_value = MagicMock(spec=ChatStream) - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - agent.chat_stream(buffer_size=512) - - client.chat_stream.assert_called_once_with( - channel="C111", - thread_ts="1234567890.123456", - recipient_team_id="T111", - recipient_user_id="W222", - buffer_size=512, - ) - - def test_chat_stream_falls_back_to_ts(self): - """When thread_ts is not set, chat_stream() falls back to ts.""" - client = MagicMock(spec=WebClient) - client.chat_stream.return_value = MagicMock(spec=ChatStream) - - agent = BoltAgent( - client=client, - channel_id="C111", - team_id="T111", - ts="1111111111.111111", - user_id="W222", - ) - stream = agent.chat_stream() - - client.chat_stream.assert_called_once_with( - channel="C111", - thread_ts="1111111111.111111", - recipient_team_id="T111", - recipient_user_id="W222", - ) - assert stream is not None - - def test_chat_stream_prefers_thread_ts_over_ts(self): - """thread_ts takes priority over ts.""" - client = MagicMock(spec=WebClient) - client.chat_stream.return_value = MagicMock(spec=ChatStream) - - agent = BoltAgent( - client=client, - channel_id="C111", - team_id="T111", - thread_ts="1234567890.123456", - ts="1111111111.111111", - user_id="W222", - ) - stream = agent.chat_stream() - - client.chat_stream.assert_called_once_with( - channel="C111", - thread_ts="1234567890.123456", - recipient_team_id="T111", - recipient_user_id="W222", - ) - assert stream is not None - - def test_set_status_uses_context_defaults(self): - """BoltAgent.set_status() passes context defaults to WebClient.assistant_threads_setStatus().""" - client = MagicMock(spec=WebClient) - client.assistant_threads_setStatus.return_value = MagicMock() - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - agent.set_status(status="Thinking...") - - client.assistant_threads_setStatus.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - status="Thinking...", - loading_messages=None, - ) - - def test_set_status_with_loading_messages(self): - """BoltAgent.set_status() forwards loading_messages.""" - client = MagicMock(spec=WebClient) - client.assistant_threads_setStatus.return_value = MagicMock() - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - agent.set_status( - status="Thinking...", - loading_messages=["Sitting...", "Waiting..."], - ) - - client.assistant_threads_setStatus.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - status="Thinking...", - loading_messages=["Sitting...", "Waiting..."], - ) - - def test_set_status_overrides_context_defaults(self): - """Explicit channel_id/thread_ts override context defaults.""" - client = MagicMock(spec=WebClient) - client.assistant_threads_setStatus.return_value = MagicMock() - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - agent.set_status( - status="Thinking...", - channel_id="C999", - thread_ts="9999999999.999999", - ) - - client.assistant_threads_setStatus.assert_called_once_with( - channel_id="C999", - thread_ts="9999999999.999999", - status="Thinking...", - loading_messages=None, - ) - - def test_set_status_passes_extra_kwargs(self): - """Extra kwargs are forwarded to WebClient.assistant_threads_setStatus().""" - client = MagicMock(spec=WebClient) - client.assistant_threads_setStatus.return_value = MagicMock() - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - agent.set_status(status="Thinking...", token="xoxb-override") - - client.assistant_threads_setStatus.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - status="Thinking...", - loading_messages=None, - token="xoxb-override", - ) - - def test_set_status_requires_status(self): - """set_status() raises TypeError when status is not provided.""" - client = MagicMock(spec=WebClient) - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - with pytest.raises(TypeError): - agent.set_status() - - def test_set_suggested_prompts_uses_context_defaults(self): - """BoltAgent.set_suggested_prompts() passes context defaults to WebClient.assistant_threads_setSuggestedPrompts().""" - client = MagicMock(spec=WebClient) - client.assistant_threads_setSuggestedPrompts.return_value = MagicMock() - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - agent.set_suggested_prompts(prompts=["What can you do?", "Help me write code"]) - - client.assistant_threads_setSuggestedPrompts.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - prompts=[ - {"title": "What can you do?", "message": "What can you do?"}, - {"title": "Help me write code", "message": "Help me write code"}, - ], - title=None, - ) - - def test_set_suggested_prompts_with_dict_prompts(self): - """BoltAgent.set_suggested_prompts() accepts dict prompts with title and message.""" - client = MagicMock(spec=WebClient) - client.assistant_threads_setSuggestedPrompts.return_value = MagicMock() - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - agent.set_suggested_prompts( - prompts=[ - {"title": "Short title", "message": "A much longer message for this prompt"}, - ], - title="Suggestions", - ) - - client.assistant_threads_setSuggestedPrompts.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - prompts=[ - {"title": "Short title", "message": "A much longer message for this prompt"}, - ], - title="Suggestions", - ) - - def test_set_suggested_prompts_overrides_context_defaults(self): - """Explicit channel_id/thread_ts override context defaults.""" - client = MagicMock(spec=WebClient) - client.assistant_threads_setSuggestedPrompts.return_value = MagicMock() - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - agent.set_suggested_prompts( - prompts=["Hello"], - channel_id="C999", - thread_ts="9999999999.999999", - ) - - client.assistant_threads_setSuggestedPrompts.assert_called_once_with( - channel_id="C999", - thread_ts="9999999999.999999", - prompts=[{"title": "Hello", "message": "Hello"}], - title=None, - ) - - def test_set_suggested_prompts_passes_extra_kwargs(self): - """Extra kwargs are forwarded to WebClient.assistant_threads_setSuggestedPrompts().""" - client = MagicMock(spec=WebClient) - client.assistant_threads_setSuggestedPrompts.return_value = MagicMock() - - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - agent.set_suggested_prompts(prompts=["Hello"], token="xoxb-override") - - client.assistant_threads_setSuggestedPrompts.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - prompts=[{"title": "Hello", "message": "Hello"}], - title=None, - token="xoxb-override", - ) - - def test_set_suggested_prompts_requires_prompts(self): - """set_suggested_prompts() raises TypeError when prompts is not provided.""" - client = MagicMock(spec=WebClient) - agent = BoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - with pytest.raises(TypeError): - agent.set_suggested_prompts() - - def test_import_from_slack_bolt(self): - from slack_bolt import BoltAgent as ImportedBoltAgent - - assert ImportedBoltAgent is BoltAgent - - def test_import_from_agent_module(self): - from slack_bolt.agent import BoltAgent as ImportedBoltAgent - - assert ImportedBoltAgent is BoltAgent diff --git a/tests/slack_bolt/context/test_say_stream.py b/tests/slack_bolt/context/test_say_stream.py new file mode 100644 index 000000000..be8cec812 --- /dev/null +++ b/tests/slack_bolt/context/test_say_stream.py @@ -0,0 +1,107 @@ +from unittest.mock import MagicMock + +import pytest +from slack_sdk.web import WebClient +from slack_sdk.web.chat_stream import ChatStream + +from slack_bolt.context.say_stream.say_stream import SayStream + + +# TODO: VALIDATE THIS AI SLOP IS CORRECT +class TestSayStream: + def test_uses_context_defaults(self): + client = MagicMock(spec=WebClient) + client.chat_stream.return_value = MagicMock(spec=ChatStream) + + say_stream = SayStream( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + stream = say_stream() + + client.chat_stream.assert_called_once_with( + channel="C111", + thread_ts="1234567890.123456", + recipient_team_id="T111", + recipient_user_id="W222", + ) + assert stream is not None + + def test_overrides_context_defaults(self): + client = MagicMock(spec=WebClient) + client.chat_stream.return_value = MagicMock(spec=ChatStream) + + say_stream = SayStream( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + stream = say_stream( + channel="C999", + thread_ts="9999999999.999999", + recipient_team_id="T999", + recipient_user_id="U999", + ) + + client.chat_stream.assert_called_once_with( + channel="C999", + thread_ts="9999999999.999999", + recipient_team_id="T999", + recipient_user_id="U999", + ) + assert stream is not None + + def test_rejects_partial_overrides(self): + client = MagicMock(spec=WebClient) + say_stream = SayStream( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + with pytest.raises(ValueError, match="Either provide all of"): + say_stream(channel="C999") + + def test_passes_extra_kwargs(self): + client = MagicMock(spec=WebClient) + client.chat_stream.return_value = MagicMock(spec=ChatStream) + + say_stream = SayStream( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + say_stream(buffer_size=512) + + client.chat_stream.assert_called_once_with( + channel="C111", + thread_ts="1234567890.123456", + recipient_team_id="T111", + recipient_user_id="W222", + buffer_size=512, + ) + + def test_raises_without_channel_id(self): + client = MagicMock(spec=WebClient) + say_stream = SayStream(client=client, channel_id=None, thread_ts="1234567890.123456") + with pytest.raises(ValueError, match="no channel_id"): + say_stream() + + def test_raises_without_thread_ts(self): + client = MagicMock(spec=WebClient) + say_stream = SayStream(client=client, channel_id="C111", thread_ts=None) + with pytest.raises(ValueError, match="no thread_ts"): + say_stream() + + def test_import_from_slack_bolt(self): + from slack_bolt import SayStream as ImportedSayStream + + assert ImportedSayStream is SayStream diff --git a/tests/slack_bolt/agent/__init__.py b/tests/slack_bolt/middleware/attaching_agent_kwargs/__init__.py similarity index 100% rename from tests/slack_bolt/agent/__init__.py rename to tests/slack_bolt/middleware/attaching_agent_kwargs/__init__.py diff --git a/tests/slack_bolt/middleware/attaching_agent_kwargs/test_attaching_agent_kwargs.py b/tests/slack_bolt/middleware/attaching_agent_kwargs/test_attaching_agent_kwargs.py new file mode 100644 index 000000000..5811a9def --- /dev/null +++ b/tests/slack_bolt/middleware/attaching_agent_kwargs/test_attaching_agent_kwargs.py @@ -0,0 +1,192 @@ +from time import sleep + +from slack_sdk.web import WebClient + +from slack_bolt import App, BoltRequest +from slack_bolt.context.set_status.set_status import SetStatus +from slack_bolt.context.set_suggested_prompts.set_suggested_prompts import SetSuggestedPrompts +from slack_bolt.context.set_title.set_title import SetTitle +from tests.mock_web_api_server import ( + setup_mock_web_api_server, + cleanup_mock_web_api_server, +) +from tests.utils import remove_os_env_temporarily, restore_os_env + + +# TODO: VALIDATE THIS AI SLOP IS CORRECT +def build_event_body(event: dict) -> dict: + return { + "token": "verification_token", + "team_id": "T111", + "enterprise_id": "E111", + "api_app_id": "A111", + "event": event, + "type": "event_callback", + "event_id": "Ev111", + "event_time": 1599616881, + "authorizations": [ + { + "enterprise_id": "E111", + "team_id": "T111", + "user_id": "W111", + "is_bot": True, + "is_enterprise_install": False, + } + ], + } + + +top_level_message_body = build_event_body( + { + "type": "message", + "user": "W222", + "text": "hello", + "ts": "1234567890.123456", + "channel": "C111", + "event_ts": "1234567890.123456", + } +) + +threaded_message_body = build_event_body( + { + "type": "message", + "user": "W222", + "text": "hello in thread", + "ts": "1234567890.999999", + "thread_ts": "1234567890.123456", + "channel": "C111", + "event_ts": "1234567890.999999", + } +) + +app_mention_body = build_event_body( + { + "type": "app_mention", + "user": "W222", + "text": "<@W111> hello", + "ts": "1234567890.123456", + "channel": "C111", + "event_ts": "1234567890.123456", + } +) + +no_channel_event_body = build_event_body( + { + "type": "team_join", + "user": {"id": "W222"}, + } +) + + +class TestAttachingAgentKwargs: + valid_token = "xoxb-valid" + mock_api_server_base_url = "http://localhost:8888" + web_client = WebClient( + token=valid_token, + base_url=mock_api_server_base_url, + ) + + def setup_method(self): + self.old_os_env = remove_os_env_temporarily() + setup_mock_web_api_server(self) + + def teardown_method(self): + cleanup_mock_web_api_server(self) + restore_os_env(self.old_os_env) + + def _wait_for(self, state, key="called", timeout=2.0): + count = 0 + while state[key] is False and count < timeout / 0.1: + sleep(0.1) + count += 1 + assert state[key] is True + + def test_top_level_message_uses_ts(self): + app = App(client=self.web_client) + state = {"called": False} + + @app.event("message") + def handle(context): + assert isinstance(context["set_status"], SetStatus) + assert isinstance(context["set_title"], SetTitle) + assert isinstance(context["set_suggested_prompts"], SetSuggestedPrompts) + assert context["set_status"].thread_ts == "1234567890.123456" + state["called"] = True + + response = app.dispatch(BoltRequest(body=top_level_message_body, mode="socket_mode")) + assert response.status == 200 + self._wait_for(state) + + def test_threaded_message_uses_thread_ts(self): + app = App(client=self.web_client) + state = {"called": False} + + @app.event("message") + def handle(context): + assert isinstance(context["set_status"], SetStatus) + assert context["set_status"].thread_ts == "1234567890.123456" + state["called"] = True + + response = app.dispatch(BoltRequest(body=threaded_message_body, mode="socket_mode")) + assert response.status == 200 + self._wait_for(state) + + def test_app_mention_event(self): + app = App(client=self.web_client) + state = {"called": False} + + @app.event("app_mention") + def handle(context): + assert isinstance(context["set_status"], SetStatus) + assert isinstance(context["set_title"], SetTitle) + assert isinstance(context["set_suggested_prompts"], SetSuggestedPrompts) + assert context["set_status"].thread_ts == "1234567890.123456" + state["called"] = True + + response = app.dispatch(BoltRequest(body=app_mention_body, mode="socket_mode")) + assert response.status == 200 + self._wait_for(state) + + def test_message_listener_top_level(self): + app = App(client=self.web_client) + state = {"called": False} + + @app.message("hello") + def handle(context): + assert isinstance(context["set_status"], SetStatus) + assert context["set_status"].thread_ts == "1234567890.123456" + state["called"] = True + + response = app.dispatch(BoltRequest(body=top_level_message_body, mode="socket_mode")) + assert response.status == 200 + self._wait_for(state) + + def test_no_channel_id_skips_gracefully(self): + app = App(client=self.web_client) + state = {"called": False} + + @app.event("team_join") + def handle(context): + assert "set_status" not in context + assert "set_title" not in context + assert "set_suggested_prompts" not in context + state["called"] = True + + response = app.dispatch(BoltRequest(body=no_channel_event_body, mode="socket_mode")) + assert response.status == 200 + self._wait_for(state) + + def test_opt_out(self): + app = App(client=self.web_client, attaching_agent_kwargs_enabled=False) + state = {"called": False} + + @app.event("message") + def handle(context): + assert "set_status" not in context + assert "set_title" not in context + assert "set_suggested_prompts" not in context + state["called"] = True + + response = app.dispatch(BoltRequest(body=top_level_message_body, mode="socket_mode")) + assert response.status == 200 + self._wait_for(state) diff --git a/tests/slack_bolt/request/test_internals.py b/tests/slack_bolt/request/test_internals.py index 752fa6d2d..8a0deca5e 100644 --- a/tests/slack_bolt/request/test_internals.py +++ b/tests/slack_bolt/request/test_internals.py @@ -13,6 +13,7 @@ extract_actor_team_id, extract_actor_user_id, extract_function_execution_id, + extract_thread_ts, ) @@ -111,6 +112,143 @@ def teardown_method(self): }, ] + thread_ts_event_requests = [ + { + "event": { + "type": "app_mention", + "channel": "C111", + "user": "U111", + "ts": "123.420", + "thread_ts": "123.456", + }, + }, + { + "event": { + "type": "message", + "channel": "C111", + "user": "U111", + "ts": "123.420", + "thread_ts": "123.456", + }, + }, + { + "event": { + "type": "message", + "subtype": "bot_message", + "channel": "C111", + "bot_id": "B111", + "ts": "123.420", + "thread_ts": "123.456", + }, + }, + { + "event": { + "type": "message", + "subtype": "file_share", + "channel": "C111", + "user": "U111", + "ts": "123.420", + "thread_ts": "123.456", + }, + }, + { + "event": { + "type": "message", + "subtype": "thread_broadcast", + "channel": "C111", + "user": "U111", + "ts": "123.420", + "thread_ts": "123.456", + "root": {"thread_ts": "123.420"}, + }, + }, + { + "event": { + "type": "link_shared", + "channel": "C111", + "user": "U111", + "thread_ts": "123.456", + "links": [{"url": "https://example.com"}], + }, + }, + { + "event": { + "type": "message", + "subtype": "message_changed", + "channel": "C111", + "message": { + "type": "message", + "user": "U111", + "text": "edited", + "ts": "123.420", + "thread_ts": "123.456", + }, + }, + }, + { + "event": { + "type": "message", + "subtype": "message_changed", + "channel": "C111", + "message": { + "type": "message", + "user": "U111", + "text": "edited", + "ts": "123.420", + "thread_ts": "123.456", + }, + "previous_message": { + "type": "message", + "user": "U111", + "text": "deleted", + "ts": "123.420", + "thread_ts": "123.420", + }, + }, + }, + { + "event": { + "type": "message", + "subtype": "message_deleted", + "channel": "C111", + "previous_message": { + "type": "message", + "user": "U111", + "text": "deleted", + "ts": "123.420", + "thread_ts": "123.456", + }, + }, + }, + ] + + no_thread_ts_requests = [ + { + "event": { + "type": "reaction_added", + "user": "U111", + "reaction": "thumbsup", + "item": {"type": "message", "channel": "C111", "ts": "123.420"}, + }, + }, + { + "event": { + "type": "channel_created", + "channel": {"id": "C222", "name": "test", "created": 1678455198}, + }, + }, + { + "event": { + "type": "message", + "channel": "C111", + "user": "U111", + "text": "hello", + "ts": "123.420", + }, + }, + {}, + ] + slack_connect_authorizations = [ { "enterprise_id": "INSTALLED_ENTERPRISE_ID", @@ -223,10 +361,10 @@ def teardown_method(self): "type": "message", "text": "<@INSTALLED_BOT_USER_ID> Hey!", "user": "USER_ID_ACTOR", - "ts": "1678455198.838499", + "ts": "123.456", "team": "TEAM_ID_ACTOR", "channel": "C111", - "event_ts": "1678455198.838499", + "event_ts": "123.456", "channel_type": "channel", }, "type": "event_callback", @@ -337,6 +475,47 @@ def test_function_inputs_extraction(self): inputs = extract_function_inputs(req) assert inputs == {"customer_id": "Ux111"} + def test_extract_thread_ts(self): + for req in self.thread_ts_event_requests: + thread_ts = extract_thread_ts(req) + assert thread_ts == "123.456", f"Expected thread_ts for {req}" + + def test_extract_thread_ts_fail(self): + for req in self.no_thread_ts_requests: + thread_ts = extract_thread_ts(req) + assert thread_ts is None, f"Expected None for {req}" + + def test_extract_thread_ts_edge_cases(self): + # message_changed where only previous_message has thread_ts (no message key) + req = { + "event": { + "type": "message", + "subtype": "message_deleted", + "channel": "C111", + "previous_message": { + "type": "message", + "ts": "1678455205.000000", + "thread_ts": "123.456", + }, + }, + } + assert extract_thread_ts(req) == "123.456" + + # Payload with thread_ts directly at root level (non-event payload) + req = {"thread_ts": "123.456"} + assert extract_thread_ts(req) == "123.456" + + # Event with thread_ts as empty string (truthy check: empty string is falsy) + req = { + "event": { + "type": "message", + "channel": "C111", + "thread_ts": "", + }, + } + # Empty string is falsy, so .get() returns "" but `is not None` is True + assert extract_thread_ts(req) == "" + def test_is_enterprise_install_extraction(self): for req in self.requests: should_be_false = extract_is_enterprise_install(req) diff --git a/tests/slack_bolt_async/agent/test_async_agent.py b/tests/slack_bolt_async/agent/test_async_agent.py deleted file mode 100644 index 3ed8ef0b4..000000000 --- a/tests/slack_bolt_async/agent/test_async_agent.py +++ /dev/null @@ -1,399 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from slack_sdk.web.async_client import AsyncWebClient -from slack_sdk.web.async_chat_stream import AsyncChatStream - -from slack_bolt.agent.async_agent import AsyncBoltAgent - - -def _make_async_chat_stream_mock(): - mock_stream = MagicMock(spec=AsyncChatStream) - call_tracker = MagicMock() - - async def fake_chat_stream(**kwargs): - call_tracker(**kwargs) - return mock_stream - - return fake_chat_stream, call_tracker, mock_stream - - -def _make_async_api_mock(): - mock_response = MagicMock() - call_tracker = MagicMock() - - async def fake_api_call(**kwargs): - call_tracker(**kwargs) - return mock_response - - return fake_api_call, call_tracker, mock_response - - -class TestAsyncBoltAgent: - @pytest.mark.asyncio - async def test_chat_stream_uses_context_defaults(self): - """AsyncBoltAgent.chat_stream() passes context defaults to AsyncWebClient.chat_stream().""" - client = MagicMock(spec=AsyncWebClient) - client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - stream = await agent.chat_stream() - - call_tracker.assert_called_once_with( - channel="C111", - thread_ts="1234567890.123456", - recipient_team_id="T111", - recipient_user_id="W222", - ) - assert stream is not None - - @pytest.mark.asyncio - async def test_chat_stream_overrides_context_defaults(self): - """Explicit kwargs to chat_stream() override context defaults.""" - client = MagicMock(spec=AsyncWebClient) - client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - stream = await agent.chat_stream( - channel="C999", - thread_ts="9999999999.999999", - recipient_team_id="T999", - recipient_user_id="U999", - ) - - call_tracker.assert_called_once_with( - channel="C999", - thread_ts="9999999999.999999", - recipient_team_id="T999", - recipient_user_id="U999", - ) - assert stream is not None - - @pytest.mark.asyncio - async def test_chat_stream_rejects_partial_overrides(self): - """Passing only some of the four context args raises ValueError.""" - client = MagicMock(spec=AsyncWebClient) - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - with pytest.raises(ValueError, match="Either provide all of"): - await agent.chat_stream(channel="C999") - - @pytest.mark.asyncio - async def test_chat_stream_passes_extra_kwargs(self): - """Extra kwargs are forwarded to AsyncWebClient.chat_stream().""" - client = MagicMock(spec=AsyncWebClient) - client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - await agent.chat_stream(buffer_size=512) - - call_tracker.assert_called_once_with( - channel="C111", - thread_ts="1234567890.123456", - recipient_team_id="T111", - recipient_user_id="W222", - buffer_size=512, - ) - - @pytest.mark.asyncio - async def test_chat_stream_falls_back_to_ts(self): - """When thread_ts is not set, chat_stream() falls back to ts.""" - client = MagicMock(spec=AsyncWebClient) - client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - team_id="T111", - ts="1111111111.111111", - user_id="W222", - ) - stream = await agent.chat_stream() - - call_tracker.assert_called_once_with( - channel="C111", - thread_ts="1111111111.111111", - recipient_team_id="T111", - recipient_user_id="W222", - ) - assert stream is not None - - @pytest.mark.asyncio - async def test_chat_stream_prefers_thread_ts_over_ts(self): - """thread_ts takes priority over ts.""" - client = MagicMock(spec=AsyncWebClient) - client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - team_id="T111", - thread_ts="1234567890.123456", - ts="1111111111.111111", - user_id="W222", - ) - stream = await agent.chat_stream() - - call_tracker.assert_called_once_with( - channel="C111", - thread_ts="1234567890.123456", - recipient_team_id="T111", - recipient_user_id="W222", - ) - assert stream is not None - - @pytest.mark.asyncio - async def test_set_status_uses_context_defaults(self): - """AsyncBoltAgent.set_status() passes context defaults to AsyncWebClient.assistant_threads_setStatus().""" - client = MagicMock(spec=AsyncWebClient) - client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - await agent.set_status(status="Thinking...") - - call_tracker.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - status="Thinking...", - loading_messages=None, - ) - - @pytest.mark.asyncio - async def test_set_status_with_loading_messages(self): - """AsyncBoltAgent.set_status() forwards loading_messages.""" - client = MagicMock(spec=AsyncWebClient) - client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - await agent.set_status( - status="Thinking...", - loading_messages=["Sitting...", "Waiting..."], - ) - - call_tracker.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - status="Thinking...", - loading_messages=["Sitting...", "Waiting..."], - ) - - @pytest.mark.asyncio - async def test_set_status_overrides_context_defaults(self): - """Explicit channel_id/thread_ts override context defaults.""" - client = MagicMock(spec=AsyncWebClient) - client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - await agent.set_status( - status="Thinking...", - channel_id="C999", - thread_ts="9999999999.999999", - ) - - call_tracker.assert_called_once_with( - channel_id="C999", - thread_ts="9999999999.999999", - status="Thinking...", - loading_messages=None, - ) - - @pytest.mark.asyncio - async def test_set_status_passes_extra_kwargs(self): - """Extra kwargs are forwarded to AsyncWebClient.assistant_threads_setStatus().""" - client = MagicMock(spec=AsyncWebClient) - client.assistant_threads_setStatus, call_tracker, _ = _make_async_api_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - await agent.set_status(status="Thinking...", token="xoxb-override") - - call_tracker.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - status="Thinking...", - loading_messages=None, - token="xoxb-override", - ) - - @pytest.mark.asyncio - async def test_set_status_requires_status(self): - """set_status() raises TypeError when status is not provided.""" - client = MagicMock(spec=AsyncWebClient) - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - with pytest.raises(TypeError): - await agent.set_status() - - @pytest.mark.asyncio - async def test_set_suggested_prompts_uses_context_defaults(self): - """AsyncBoltAgent.set_suggested_prompts() passes context defaults to AsyncWebClient.assistant_threads_setSuggestedPrompts().""" - client = MagicMock(spec=AsyncWebClient) - client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - await agent.set_suggested_prompts(prompts=["What can you do?", "Help me write code"]) - - call_tracker.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - prompts=[ - {"title": "What can you do?", "message": "What can you do?"}, - {"title": "Help me write code", "message": "Help me write code"}, - ], - title=None, - ) - - @pytest.mark.asyncio - async def test_set_suggested_prompts_with_dict_prompts(self): - """AsyncBoltAgent.set_suggested_prompts() accepts dict prompts with title and message.""" - client = MagicMock(spec=AsyncWebClient) - client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - await agent.set_suggested_prompts( - prompts=[ - {"title": "Short title", "message": "A much longer message for this prompt"}, - ], - title="Suggestions", - ) - - call_tracker.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - prompts=[ - {"title": "Short title", "message": "A much longer message for this prompt"}, - ], - title="Suggestions", - ) - - @pytest.mark.asyncio - async def test_set_suggested_prompts_overrides_context_defaults(self): - """Explicit channel_id/thread_ts override context defaults.""" - client = MagicMock(spec=AsyncWebClient) - client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - await agent.set_suggested_prompts( - prompts=["Hello"], - channel_id="C999", - thread_ts="9999999999.999999", - ) - - call_tracker.assert_called_once_with( - channel_id="C999", - thread_ts="9999999999.999999", - prompts=[{"title": "Hello", "message": "Hello"}], - title=None, - ) - - @pytest.mark.asyncio - async def test_set_suggested_prompts_passes_extra_kwargs(self): - """Extra kwargs are forwarded to AsyncWebClient.assistant_threads_setSuggestedPrompts().""" - client = MagicMock(spec=AsyncWebClient) - client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock() - - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - await agent.set_suggested_prompts(prompts=["Hello"], token="xoxb-override") - - call_tracker.assert_called_once_with( - channel_id="C111", - thread_ts="1234567890.123456", - prompts=[{"title": "Hello", "message": "Hello"}], - title=None, - token="xoxb-override", - ) - - @pytest.mark.asyncio - async def test_set_suggested_prompts_requires_prompts(self): - """set_suggested_prompts() raises TypeError when prompts is not provided.""" - client = MagicMock(spec=AsyncWebClient) - agent = AsyncBoltAgent( - client=client, - channel_id="C111", - thread_ts="1234567890.123456", - team_id="T111", - user_id="W222", - ) - with pytest.raises(TypeError): - await agent.set_suggested_prompts() - - @pytest.mark.asyncio - async def test_import_from_agent_module(self): - from slack_bolt.agent.async_agent import AsyncBoltAgent as ImportedAsyncBoltAgent - - assert ImportedAsyncBoltAgent is AsyncBoltAgent diff --git a/tests/slack_bolt_async/context/test_async_say_stream.py b/tests/slack_bolt_async/context/test_async_say_stream.py new file mode 100644 index 000000000..c52901d01 --- /dev/null +++ b/tests/slack_bolt_async/context/test_async_say_stream.py @@ -0,0 +1,119 @@ +from unittest.mock import MagicMock + +import pytest +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.web.async_chat_stream import AsyncChatStream + +from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream + + +# TODO: VALIDATE THIS AI SLOP IS CORRECT +def _make_async_chat_stream_mock(): + mock_stream = MagicMock(spec=AsyncChatStream) + call_tracker = MagicMock() + + async def fake_chat_stream(**kwargs): + call_tracker(**kwargs) + return mock_stream + + return fake_chat_stream, call_tracker, mock_stream + + +class TestAsyncSayStream: + @pytest.mark.asyncio + async def test_uses_context_defaults(self): + client = MagicMock(spec=AsyncWebClient) + client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() + + say_stream = AsyncSayStream( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + stream = await say_stream() + + call_tracker.assert_called_once_with( + channel="C111", + thread_ts="1234567890.123456", + recipient_team_id="T111", + recipient_user_id="W222", + ) + assert stream is not None + + @pytest.mark.asyncio + async def test_overrides_context_defaults(self): + client = MagicMock(spec=AsyncWebClient) + client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() + + say_stream = AsyncSayStream( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + stream = await say_stream( + channel="C999", + thread_ts="9999999999.999999", + recipient_team_id="T999", + recipient_user_id="U999", + ) + + call_tracker.assert_called_once_with( + channel="C999", + thread_ts="9999999999.999999", + recipient_team_id="T999", + recipient_user_id="U999", + ) + assert stream is not None + + @pytest.mark.asyncio + async def test_rejects_partial_overrides(self): + client = MagicMock(spec=AsyncWebClient) + say_stream = AsyncSayStream( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + with pytest.raises(ValueError, match="Either provide all of"): + await say_stream(channel="C999") + + @pytest.mark.asyncio + async def test_passes_extra_kwargs(self): + client = MagicMock(spec=AsyncWebClient) + client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() + + say_stream = AsyncSayStream( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + await say_stream(buffer_size=512) + + call_tracker.assert_called_once_with( + channel="C111", + thread_ts="1234567890.123456", + recipient_team_id="T111", + recipient_user_id="W222", + buffer_size=512, + ) + + @pytest.mark.asyncio + async def test_raises_without_channel_id(self): + client = MagicMock(spec=AsyncWebClient) + say_stream = AsyncSayStream(client=client, channel_id=None, thread_ts="1234567890.123456") + with pytest.raises(ValueError, match="no channel_id"): + await say_stream() + + @pytest.mark.asyncio + async def test_raises_without_thread_ts(self): + client = MagicMock(spec=AsyncWebClient) + say_stream = AsyncSayStream(client=client, channel_id="C111", thread_ts=None) + with pytest.raises(ValueError, match="no thread_ts"): + await say_stream() diff --git a/tests/slack_bolt_async/agent/__init__.py b/tests/slack_bolt_async/middleware/attaching_agent_kwargs/__init__.py similarity index 100% rename from tests/slack_bolt_async/agent/__init__.py rename to tests/slack_bolt_async/middleware/attaching_agent_kwargs/__init__.py diff --git a/tests/slack_bolt_async/middleware/attaching_agent_kwargs/test_async_attaching_agent_kwargs.py b/tests/slack_bolt_async/middleware/attaching_agent_kwargs/test_async_attaching_agent_kwargs.py new file mode 100644 index 000000000..8250ddc23 --- /dev/null +++ b/tests/slack_bolt_async/middleware/attaching_agent_kwargs/test_async_attaching_agent_kwargs.py @@ -0,0 +1,202 @@ +import asyncio + +import pytest +from slack_sdk.web.async_client import AsyncWebClient + +from slack_bolt.app.async_app import AsyncApp +from slack_bolt.context.set_status.async_set_status import AsyncSetStatus +from slack_bolt.context.set_suggested_prompts.async_set_suggested_prompts import AsyncSetSuggestedPrompts +from slack_bolt.context.set_title.async_set_title import AsyncSetTitle +from slack_bolt.request.async_request import AsyncBoltRequest +from tests.mock_web_api_server import ( + cleanup_mock_web_api_server_async, + setup_mock_web_api_server_async, +) +from tests.utils import remove_os_env_temporarily, restore_os_env + + +# TODO: VALIDATE THIS AI SLOP IS CORRECT +def build_event_body(event: dict) -> dict: + return { + "token": "verification_token", + "team_id": "T111", + "enterprise_id": "E111", + "api_app_id": "A111", + "event": event, + "type": "event_callback", + "event_id": "Ev111", + "event_time": 1599616881, + "authorizations": [ + { + "enterprise_id": "E111", + "team_id": "T111", + "user_id": "W111", + "is_bot": True, + "is_enterprise_install": False, + } + ], + } + + +top_level_message_body = build_event_body( + { + "type": "message", + "user": "W222", + "text": "hello", + "ts": "1234567890.123456", + "channel": "C111", + "event_ts": "1234567890.123456", + } +) + +threaded_message_body = build_event_body( + { + "type": "message", + "user": "W222", + "text": "hello in thread", + "ts": "1234567890.999999", + "thread_ts": "1234567890.123456", + "channel": "C111", + "event_ts": "1234567890.999999", + } +) + +app_mention_body = build_event_body( + { + "type": "app_mention", + "user": "W222", + "text": "<@W111> hello", + "ts": "1234567890.123456", + "channel": "C111", + "event_ts": "1234567890.123456", + } +) + +no_channel_event_body = build_event_body( + { + "type": "team_join", + "user": {"id": "W222"}, + } +) + + +class TestAsyncAttachingAgentKwargs: + valid_token = "xoxb-valid" + mock_api_server_base_url = "http://localhost:8888" + web_client = AsyncWebClient( + token=valid_token, + base_url=mock_api_server_base_url, + ) + + @pytest.fixture(scope="function", autouse=True) + def setup_teardown(self): + old_os_env = remove_os_env_temporarily() + setup_mock_web_api_server_async(self) + try: + yield + finally: + cleanup_mock_web_api_server_async(self) + restore_os_env(old_os_env) + + async def _wait_for(self, state, key="called", timeout=2.0): + count = 0 + while state[key] is False and count < timeout / 0.1: + await asyncio.sleep(0.1) + count += 1 + assert state[key] is True + + @pytest.mark.asyncio + async def test_top_level_message_uses_ts(self): + app = AsyncApp(client=self.web_client) + state = {"called": False} + + @app.event("message") + async def handle(context): + assert isinstance(context["set_status"], AsyncSetStatus) + assert isinstance(context["set_title"], AsyncSetTitle) + assert isinstance(context["set_suggested_prompts"], AsyncSetSuggestedPrompts) + assert context["set_status"].thread_ts == "1234567890.123456" + state["called"] = True + + response = await app.async_dispatch(AsyncBoltRequest(body=top_level_message_body, mode="socket_mode")) + assert response.status == 200 + await self._wait_for(state) + + @pytest.mark.asyncio + async def test_threaded_message_uses_thread_ts(self): + app = AsyncApp(client=self.web_client) + state = {"called": False} + + @app.event("message") + async def handle(context): + assert isinstance(context["set_status"], AsyncSetStatus) + assert context["set_status"].thread_ts == "1234567890.123456" + state["called"] = True + + response = await app.async_dispatch(AsyncBoltRequest(body=threaded_message_body, mode="socket_mode")) + assert response.status == 200 + await self._wait_for(state) + + @pytest.mark.asyncio + async def test_app_mention_event(self): + app = AsyncApp(client=self.web_client) + state = {"called": False} + + @app.event("app_mention") + async def handle(context): + assert isinstance(context["set_status"], AsyncSetStatus) + assert isinstance(context["set_title"], AsyncSetTitle) + assert isinstance(context["set_suggested_prompts"], AsyncSetSuggestedPrompts) + assert context["set_status"].thread_ts == "1234567890.123456" + state["called"] = True + + response = await app.async_dispatch(AsyncBoltRequest(body=app_mention_body, mode="socket_mode")) + assert response.status == 200 + await self._wait_for(state) + + @pytest.mark.asyncio + async def test_message_listener_top_level(self): + app = AsyncApp(client=self.web_client) + state = {"called": False} + + @app.message("hello") + async def handle(context): + assert isinstance(context["set_status"], AsyncSetStatus) + assert context["set_status"].thread_ts == "1234567890.123456" + state["called"] = True + + response = await app.async_dispatch(AsyncBoltRequest(body=top_level_message_body, mode="socket_mode")) + assert response.status == 200 + await self._wait_for(state) + + @pytest.mark.asyncio + async def test_no_channel_id_skips_gracefully(self): + app = AsyncApp(client=self.web_client) + state = {"called": False} + + @app.event("team_join") + async def handle(context): + assert "set_status" not in context + assert "set_title" not in context + assert "set_suggested_prompts" not in context + state["called"] = True + + response = await app.async_dispatch(AsyncBoltRequest(body=no_channel_event_body, mode="socket_mode")) + assert response.status == 200 + await self._wait_for(state) + + @pytest.mark.asyncio + async def test_opt_out(self): + app = AsyncApp(client=self.web_client, attaching_agent_kwargs_enabled=False) + state = {"called": False} + + @app.event("message") + async def handle(context): + assert "set_status" not in context + assert "set_title" not in context + assert "set_suggested_prompts" not in context + state["called"] = True + + response = await app.async_dispatch(AsyncBoltRequest(body=top_level_message_body, mode="socket_mode")) + assert response.status == 200 + await self._wait_for(state)