diff --git a/src/sap_cloud_sdk/extensibility/client.py b/src/sap_cloud_sdk/extensibility/client.py index 01267bee..5066f12a 100644 --- a/src/sap_cloud_sdk/extensibility/client.py +++ b/src/sap_cloud_sdk/extensibility/client.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import itertools import json import logging @@ -15,6 +16,7 @@ from sap_cloud_sdk.core.telemetry import Module, Operation from sap_cloud_sdk.core.telemetry.metrics_decorator import record_metrics +from sap_cloud_sdk.agentgateway import create_client as create_agw_client from sap_cloud_sdk.extensibility._models import ( DEFAULT_EXTENSION_CAPABILITY_ID, ExtensionCapabilityImplementation, @@ -38,6 +40,7 @@ _EXECUTE_WORKFLOW_TOOL_NAME = "execute_workflow" _GET_EXECUTION_TOOL_NAME = "get_execution" +_N8N_MCP_SERVER_NAME = "sap.btpn8n:apiResource:ManagedN8nMcpServer:v1" _JSONRPC_VERSION = "2.0" @@ -220,34 +223,34 @@ def call_hook( hook_config: HookConfig, ) -> Optional[Message]: """Call a hook's MCP endpoint and poll until completion. - + Executes the workflow via ``execute-workflow``, then polls ``get-execution`` every 500 ms until the execution succeeds, fails, or ``hook.timeout`` seconds elapse. - + This method is transport-agnostic: regardless of how extension metadata was fetched (backend, local file, or no-op), the actual hook invocation is always a direct HTTP call to the URL embedded in the :class:`Hook` object. - + Args: hook: Hook configuration (workflow ID, method, timeout). hook_config: Hook invocation configuration (endpoint URL, auth token, optional payload). - + Returns: Parsed ``Message`` from the last executed workflow node, or ``None`` if the hook completed successfully but produced no message. - + Raises: TransportError: On HTTP errors, terminal execution failures, or timeout. - + Example: ```python from sap_cloud_sdk.extensibility import create_client - + client = create_client("sap.ai:agent:myAgent:v1") impl = client.get_extension_capability_implementation(tenant="tenant-abc") - + if impl.hooks: hook = impl.hooks[0] result = client.call_hook( @@ -262,13 +265,13 @@ def call_hook( """ headers = {**_JSONRPC_HEADERS} inject(headers) - + message_payload: dict[str, Any] = {} if hook_config.payload is not None: model_dump = getattr(hook_config.payload, "model_dump", None) if callable(model_dump): message_payload = cast(dict[str, Any], model_dump(exclude_none=True)) - + # 1. Execute workflow execute_workflow_arguments = { "workflowId": hook.n8n_workflow_config.workflow_id, @@ -282,7 +285,7 @@ def call_hook( }, }, } - + try: with httpx.Client( headers={"Authorization": f"Bearer {hook_config.auth_token}"}, @@ -301,16 +304,16 @@ def call_hook( raise TransportError( f"HTTP request to hook MCP endpoint failed: {exc}" ) from exc - + try: data = _extract_tool_result(_parse_response(tool_resp)) except TransportError: raise except Exception as exc: raise TransportError(f"Could not parse hook response: {exc}") from exc - + status = data.get("status", "") - + # 2. Fail fast on terminal statuses from execute-workflow if status in _EXECUTE_TERMINAL_STATUSES: error_msg = data.get("error", "") @@ -318,7 +321,7 @@ def call_hook( f"Workflow execution failed with status {status!r}" + (f": {error_msg}" if error_msg else "") ) - + # 3. Return immediately if execution completed synchronously if status == "success": try: @@ -336,7 +339,7 @@ def call_hook( raise TransportError( f"Failed to extract response from last executed node: {exc}" ) from exc - + # 4. Poll get-execution for running/new/waiting/started execution_id = data.get("executionId") get_execution_arguments = { @@ -344,12 +347,12 @@ def call_hook( "executionId": str(execution_id), "includeData": True, } - + deadline = time.monotonic() + hook.timeout last_status = status while time.monotonic() < deadline: time.sleep(_HOOK_POLL_INTERVAL) - + try: with httpx.Client( headers={"Authorization": f"Bearer {hook_config.auth_token}"}, @@ -368,13 +371,215 @@ def call_hook( raise TransportError( f"HTTP request to hook MCP endpoint failed: {exc}" ) from exc - + try: data = _extract_tool_result(_parse_response(tool_resp)) except TransportError: raise except Exception as exc: raise TransportError(f"Could not parse hook response: {exc}") from exc + + last_status = data.get("execution", {}).get("status", "") or data.get( + "status", "" + ) + + if last_status == "success": + try: + result_data = data.get("data", {}).get("resultData", {}) + last_node = result_data.get("lastNodeExecuted", "") + response_json = ( + result_data.get("runData", {}) + .get(last_node, [{}])[0] + .get("data", {}) + .get("main", [[{}]])[0][0] + .get("json", {}) + ) + return Message(**response_json) + except (KeyError, IndexError, TypeError, ValidationError) as exc: + raise TransportError( + f"Failed to extract response from last executed node: {exc}" + ) from exc + + if last_status in _EXECUTION_TERMINAL_STATUSES: + error_msg = data.get("error", "") + raise TransportError( + f"Workflow execution failed with status {last_status!r}" + + (f": {error_msg}" if error_msg else "") + ) + + # Continue polling for: running, waiting, new, unknown + + raise TransportError( + f"Workflow execution timed out after {hook.timeout}s. " + f"Last status: {last_status!r}" + ) + + + @record_metrics( + Module.EXTENSIBILITY, + Operation.EXTENSIBILITY_CALL_HOOK, + ) + async def call_hook( + self, + hook: Hook, + user_token: Optional[str] = None, + message: Optional[Any] = None, + headers: Optional[dict] = None, + tenant_subdomain: Optional[str] = None + ) -> Optional[Message]: + """Call a hook via Agent Gateway MCP tool invocation. + + Discovers the N8N MCP tools via Agent Gateway, executes the workflow via + ``execute_workflow``, then polls ``get_execution`` every 500 ms until the + execution succeeds, fails, or ``hook.timeout`` seconds elapse. + + Auth and endpoint resolution are handled internally by the AGW client — + no manual token or URL configuration is required. + + Args: + hook: Hook configuration (workflow ID, method, timeout). + agw_client: Configured Agent Gateway client used for tool discovery + and invocation. + + Returns: + Parsed ``Message`` from the last executed workflow node, or ``None`` + if the hook completed successfully but produced no message. + + Raises: + TransportError: On AGW tool call errors, terminal execution failures, + or timeout. + + Example: + ```python + from sap_cloud_sdk.extensibility import call_hook + from sap_cloud_sdk.agentgateway import create_client as create_agw_client + + agw_client = create_agw_client(tenant_subdomain="my-tenant") + + result = await call_hook( + hook=impl.hooks[0], + agw_client=agw_client, + ) + ``` + """ + # 1. Create AGW client for the given tenant subdomain. + agw_client = None + agw_client = create_agw_client(tenant_subdomain) + + # 2. Discover MCP tools — AGW resolves N8N GTID and handles auth internally + # TODO: Cache the list of mcp tools for performance. + tools = await agw_client.list_mcp_tools(user_token=user_token or None) + + execute_tool = next( + ( + t for t in tools + if t.name == _EXECUTE_WORKFLOW_TOOL_NAME and t.server_name == _N8N_MCP_SERVER_NAME + ), + None, + ) + if execute_tool is None: + raise TransportError( + f"MCP tool '{_EXECUTE_WORKFLOW_TOOL_NAME}' on server '{_N8N_MCP_SERVER_NAME}' " + "not found via Agent Gateway." + ) + + get_exec_tool = next( + ( + t for t in tools + if t.name == _GET_EXECUTION_TOOL_NAME and t.server_name == _N8N_MCP_SERVER_NAME + ), + None, + ) + if get_exec_tool is None: + raise TransportError( + f"MCP tool '{_GET_EXECUTION_TOOL_NAME}' on server '{_N8N_MCP_SERVER_NAME}' " + "not found via Agent Gateway." + ) + + # 3. Execute workflow + message_body = message.model_dump(mode="json") if message is not None else {} + execute_arguments = { + "workflowId": hook.n8n_workflow_config.workflow_id, + "inputs": { + "type": "webhook", + "webhookData": { + "method": hook.n8n_workflow_config.method, + "query": {}, + "body": message_body, + "headers": headers or {}, + }, + }, + } + try: + result_str = await agw_client.call_mcp_tool( + execute_tool, + user_token=user_token or None, + **execute_arguments, + ) + except Exception as exc: + raise TransportError( + f"AGW tool call for '{_EXECUTE_WORKFLOW_TOOL_NAME}' failed: {exc}" + ) from exc + + try: + data = json.loads(result_str) + except Exception as exc: + raise TransportError(f"Could not parse hook response: {exc}") from exc + + status = data.get("status", "") + + if status in _EXECUTE_TERMINAL_STATUSES: + error_msg = data.get("error", "") + raise TransportError( + f"Workflow execution failed with status {status!r}" + + (f": {error_msg}" if error_msg else "") + ) + + if status == "success": + try: + result_data = data.get("data", {}).get("resultData", {}) + last_node = result_data.get("lastNodeExecuted", "") + response_json = ( + result_data.get("runData", {}) + .get(last_node, [{}])[0] + .get("data", {}) + .get("main", [[{}]])[0][0] + .get("json", {}) + ) + return Message(**response_json) + except (KeyError, IndexError, TypeError, ValidationError) as exc: + raise TransportError( + f"Failed to extract response from last executed node: {exc}" + ) from exc + + # 4. Poll get_execution for running/new/waiting/started + execution_id = data.get("executionId") + deadline = time.monotonic() + hook.timeout + last_status = status + + while time.monotonic() < deadline: + await asyncio.sleep(_HOOK_POLL_INTERVAL) + + try: + get_execution_arguments = { + "workflowId": hook.n8n_workflow_config.workflow_id, + "executionId": str(execution_id), + "includeData": True, + } + result_str = await agw_client.call_mcp_tool( + get_exec_tool, + user_token=user_token or None, + **get_execution_arguments, + ) + except Exception as exc: + raise TransportError( + f"AGW tool call for '{_GET_EXECUTION_TOOL_NAME}' failed: {exc}" + ) from exc + + try: + data = json.loads(result_str) + except Exception as exc: + raise TransportError(f"Could not parse hook response: {exc}") from exc last_status = data.get("execution", {}).get("status", "") or data.get( "status", "" @@ -404,8 +609,6 @@ def call_hook( + (f": {error_msg}" if error_msg else "") ) - # Continue polling for: running, waiting, new, unknown - raise TransportError( f"Workflow execution timed out after {hook.timeout}s. " f"Last status: {last_status!r}" diff --git a/tests/extensibility/unit/test_client.py b/tests/extensibility/unit/test_client.py index ef409e5d..f5c449a6 100644 --- a/tests/extensibility/unit/test_client.py +++ b/tests/extensibility/unit/test_client.py @@ -1,10 +1,17 @@ """Tests for ExtensibilityClient and create_client.""" -from unittest.mock import MagicMock, patch +import json +from unittest.mock import AsyncMock, MagicMock, patch +import pytest from sap_cloud_sdk.extensibility import create_client -from sap_cloud_sdk.extensibility.client import ExtensibilityClient +from sap_cloud_sdk.extensibility.client import ( + ExtensibilityClient, + _EXECUTE_WORKFLOW_TOOL_NAME, + _GET_EXECUTION_TOOL_NAME, + _N8N_MCP_SERVER_NAME, +) from sap_cloud_sdk.extensibility._models import ( ExtensionCapabilityImplementation, McpServer, @@ -18,6 +25,7 @@ from http import HTTPMethod from sap_cloud_sdk.extensibility.config import ExtensibilityConfig from sap_cloud_sdk.extensibility.exceptions import TransportError +from sap_cloud_sdk.agentgateway._models import MCPTool class TestCreateClient: @@ -206,3 +214,348 @@ def test_error_logging(self): client.get_extension_capability_implementation(tenant=_TENANT) mock_logger.error.assert_called_once() assert "Failed to retrieve" in mock_logger.error.call_args[0][0] + + +# --------------------------------------------------------------------------- +# Helpers shared across call_hook tests +# --------------------------------------------------------------------------- + +def _make_hook(workflow_id: str = "wf-001", timeout: int = 30) -> Hook: + return Hook( + hook_id="agent_pre_hook", + id="9f6e5f66-7e4f-4ef0-a9f6-e6e1c1220c11", + n8n_workflow_config=N8nWorkflowConfig( + workflow_id=workflow_id, + method=HTTPMethod.POST, + ), + name="Pre Hook", + type=HookType.BEFORE, + deployment_type=DeploymentType.N8N, + timeout=timeout, + execution_mode=ExecutionMode.SYNC, + on_failure=OnFailure.CONTINUE, + order=0, + can_short_circuit=True, + ) + + +def _make_n8n_tool(name: str) -> MCPTool: + """Return an MCPTool belonging to the N8N MCP server.""" + return MCPTool( + name=name, + server_name=_N8N_MCP_SERVER_NAME, + description="", + input_schema={}, + url="https://agw.example.com/v1/mcp/sap.btpn8n:apiResource:ManagedN8nMcpServer:v1/gtid-1", + ) + + +def _make_other_server_tool(name: str) -> MCPTool: + """Return an MCPTool with the same name but from a different MCP server.""" + return MCPTool( + name=name, + server_name="sap.other:apiResource:OtherMcpServer:v1", + description="", + input_schema={}, + url="https://agw.example.com/v1/mcp/other/gtid-2", + ) + + +def _success_payload(workflow_id: str = "wf-001") -> str: + return json.dumps({ + "status": "success", + "data": { + "resultData": { + "lastNodeExecuted": "Respond to Webhook", + "runData": { + "Respond to Webhook": [ + { + "data": { + "main": [ + [ + { + "json": { + "message_id": "msg-1", + "context_id": "ctx-1", + "role": 2, + } + } + ] + ] + } + } + ] + }, + } + }, + }) + + +def _running_payload(execution_id: str = "exec-1") -> str: + return json.dumps({"status": "running", "executionId": execution_id}) + + +def _poll_success_payload() -> str: + return json.dumps({ + "status": "success", + "data": { + "resultData": { + "lastNodeExecuted": "Respond to Webhook", + "runData": { + "Respond to Webhook": [ + { + "data": { + "main": [ + [ + { + "json": { + "message_id": "msg-2", + "context_id": "ctx-1", + "role": 2, + } + } + ] + ] + } + } + ] + }, + } + }, + }) + + +def _make_agw_client(tools: list, tool_responses: list) -> MagicMock: + """Build a mock AgentGatewayClient with preset list_mcp_tools and call_mcp_tool results.""" + agw = MagicMock() + agw.list_mcp_tools = AsyncMock(return_value=tools) + agw.call_mcp_tool = AsyncMock(side_effect=tool_responses) + return agw + + +# --------------------------------------------------------------------------- +# Tests for ExtensibilityClient.call_hook +# --------------------------------------------------------------------------- + + +class TestCallHook: + """Tests for ExtensibilityClient.call_hook (async, AGW-based).""" + + def _make_client(self, agw: MagicMock) -> ExtensibilityClient: + """Build an ExtensibilityClient with a mock transport and patched AGW factory.""" + client = ExtensibilityClient(MagicMock()) + # Stash the agw on the instance for the patcher closure to return. + client._test_agw = agw # type: ignore[attr-defined] + return client + + @pytest.mark.asyncio + async def test_execute_tool_not_found_raises(self): + """Raises TransportError when execute_workflow tool is absent.""" + agw = _make_agw_client(tools=[], tool_responses=[]) + client = self._make_client(agw) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ): + with pytest.raises(TransportError, match=_EXECUTE_WORKFLOW_TOOL_NAME): + await client.call_hook(hook=_make_hook(), tenant_subdomain="t") + + @pytest.mark.asyncio + async def test_get_exec_tool_not_found_raises(self): + """Raises TransportError when get_execution tool is absent.""" + tools = [_make_n8n_tool(_EXECUTE_WORKFLOW_TOOL_NAME)] + agw = _make_agw_client(tools=tools, tool_responses=[]) + client = self._make_client(agw) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ): + with pytest.raises(TransportError, match=_GET_EXECUTION_TOOL_NAME): + await client.call_hook(hook=_make_hook(), tenant_subdomain="t") + + @pytest.mark.asyncio + async def test_composite_key_ignores_wrong_server(self): + """Tools from a different server with the same names must not match.""" + tools = [ + _make_other_server_tool(_EXECUTE_WORKFLOW_TOOL_NAME), + _make_other_server_tool(_GET_EXECUTION_TOOL_NAME), + ] + agw = _make_agw_client(tools=tools, tool_responses=[]) + client = self._make_client(agw) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ): + with pytest.raises(TransportError, match=_EXECUTE_WORKFLOW_TOOL_NAME): + await client.call_hook(hook=_make_hook(), tenant_subdomain="t") + + @pytest.mark.asyncio + async def test_composite_key_picks_correct_tool_among_duplicates(self): + """Picks the N8N tool when another server exposes identically-named tools.""" + tools = [ + _make_other_server_tool(_EXECUTE_WORKFLOW_TOOL_NAME), + _make_other_server_tool(_GET_EXECUTION_TOOL_NAME), + _make_n8n_tool(_EXECUTE_WORKFLOW_TOOL_NAME), + _make_n8n_tool(_GET_EXECUTION_TOOL_NAME), + ] + agw = _make_agw_client( + tools=tools, + tool_responses=[_success_payload()], + ) + client = self._make_client(agw) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ): + result = await client.call_hook(hook=_make_hook(), tenant_subdomain="t") + assert result is not None + # call_mcp_tool must have been called with the N8N tool, not the other one + called_tool = agw.call_mcp_tool.call_args[0][0] + assert called_tool.server_name == _N8N_MCP_SERVER_NAME + + @pytest.mark.asyncio + async def test_success_synchronous(self): + """Returns a Message when execute_workflow responds with status=success.""" + tools = [ + _make_n8n_tool(_EXECUTE_WORKFLOW_TOOL_NAME), + _make_n8n_tool(_GET_EXECUTION_TOOL_NAME), + ] + agw = _make_agw_client( + tools=tools, + tool_responses=[_success_payload()], + ) + client = self._make_client(agw) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ): + result = await client.call_hook(hook=_make_hook(), tenant_subdomain="t") + assert result is not None + assert result.message_id == "msg-1" + agw.call_mcp_tool.assert_called_once() + + @pytest.mark.asyncio + async def test_success_after_polling(self): + """Returns a Message after one poll round via get_execution.""" + tools = [ + _make_n8n_tool(_EXECUTE_WORKFLOW_TOOL_NAME), + _make_n8n_tool(_GET_EXECUTION_TOOL_NAME), + ] + agw = _make_agw_client( + tools=tools, + tool_responses=[_running_payload(), _poll_success_payload()], + ) + client = self._make_client(agw) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ), patch( + "sap_cloud_sdk.extensibility.client.asyncio.sleep", + new_callable=AsyncMock, + ): + result = await client.call_hook(hook=_make_hook(), tenant_subdomain="t") + assert result is not None + assert result.message_id == "msg-2" + assert agw.call_mcp_tool.call_count == 2 + + @pytest.mark.asyncio + async def test_terminal_status_from_execute_raises(self): + """Raises TransportError on a terminal status from execute_workflow.""" + tools = [ + _make_n8n_tool(_EXECUTE_WORKFLOW_TOOL_NAME), + _make_n8n_tool(_GET_EXECUTION_TOOL_NAME), + ] + terminal_payload = json.dumps({"status": "error", "error": "workflow crashed"}) + agw = _make_agw_client(tools=tools, tool_responses=[terminal_payload]) + client = self._make_client(agw) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ): + with pytest.raises(TransportError, match="workflow crashed"): + await client.call_hook(hook=_make_hook(), tenant_subdomain="t") + + @pytest.mark.asyncio + async def test_terminal_status_from_poll_raises(self): + """Raises TransportError on a terminal status from get_execution poll.""" + tools = [ + _make_n8n_tool(_EXECUTE_WORKFLOW_TOOL_NAME), + _make_n8n_tool(_GET_EXECUTION_TOOL_NAME), + ] + poll_terminal = json.dumps({"status": "error", "error": "node failed"}) + agw = _make_agw_client( + tools=tools, + tool_responses=[_running_payload(), poll_terminal], + ) + client = self._make_client(agw) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ), patch( + "sap_cloud_sdk.extensibility.client.asyncio.sleep", + new_callable=AsyncMock, + ): + with pytest.raises(TransportError, match="node failed"): + await client.call_hook(hook=_make_hook(), tenant_subdomain="t") + + @pytest.mark.asyncio + async def test_timeout_raises(self): + """Raises TransportError when deadline is exceeded without a success status.""" + tools = [ + _make_n8n_tool(_EXECUTE_WORKFLOW_TOOL_NAME), + _make_n8n_tool(_GET_EXECUTION_TOOL_NAME), + ] + # Always returns "running" so the loop never exits via success/terminal + agw = _make_agw_client( + tools=tools, + tool_responses=[_running_payload()] + [_running_payload()] * 100, + ) + client = self._make_client(agw) + # Use a hook with timeout=0 so monotonic deadline is immediately exceeded + hook = _make_hook(timeout=0) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ), patch( + "sap_cloud_sdk.extensibility.client.asyncio.sleep", + new_callable=AsyncMock, + ): + with pytest.raises(TransportError, match="timed out"): + await client.call_hook(hook=hook, tenant_subdomain="t") + + @pytest.mark.asyncio + async def test_agw_call_mcp_tool_exception_raises_transport_error(self): + """Wraps call_mcp_tool exceptions in TransportError.""" + tools = [ + _make_n8n_tool(_EXECUTE_WORKFLOW_TOOL_NAME), + _make_n8n_tool(_GET_EXECUTION_TOOL_NAME), + ] + agw = MagicMock() + agw.list_mcp_tools = AsyncMock(return_value=tools) + agw.call_mcp_tool = AsyncMock(side_effect=RuntimeError("network error")) + client = self._make_client(agw) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ): + with pytest.raises(TransportError, match="network error"): + await client.call_hook(hook=_make_hook(), tenant_subdomain="t") + + @pytest.mark.asyncio + async def test_workflow_id_passed_to_execute_tool(self): + """Verifies the correct workflowId is forwarded to call_mcp_tool.""" + tools = [ + _make_n8n_tool(_EXECUTE_WORKFLOW_TOOL_NAME), + _make_n8n_tool(_GET_EXECUTION_TOOL_NAME), + ] + agw = _make_agw_client(tools=tools, tool_responses=[_success_payload("wf-xyz")]) + client = self._make_client(agw) + with patch( + "sap_cloud_sdk.extensibility.client.create_agw_client", + return_value=agw, + ): + await client.call_hook( + hook=_make_hook(workflow_id="wf-xyz"), tenant_subdomain="t" + ) + kwargs = agw.call_mcp_tool.call_args[1] + assert kwargs["workflowId"] == "wf-xyz"