diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 110a1263d..257dd3521 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -30,7 +30,7 @@ from .protocol import Protocol, ToolSchema from .tool import ToolboxTool from .toolbox_transport import ToolboxTransport -from .utils import identify_auth_requirements, resolve_value +from .utils import identify_auth_requirements, resolve_value, warn_if_http_and_headers class ToolboxClient: @@ -101,6 +101,7 @@ def __init__( raise ValueError(f"Unsupported MCP protocol version: {protocol}") self.__client_headers = client_headers if client_headers is not None else {} + warn_if_http_and_headers(url, self.__client_headers) def __parse_tool( self, @@ -224,6 +225,8 @@ async def load_tool( for name, val in self.__client_headers.items() } + warn_if_http_and_headers(self.__transport.base_url, auth_token_getters) + manifest = await self.__transport.tool_get(name, resolved_headers) # parse the provided definition to a tool @@ -299,6 +302,8 @@ async def load_toolset( for header_name in original_headers } + warn_if_http_and_headers(self.__transport.base_url, auth_token_getters) + manifest = await self.__transport.tools_list(name, resolved_headers) tools: list[ToolboxTool] = [] diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 0c72c39b3..ecc3f568a 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -18,7 +18,6 @@ from inspect import Signature from types import MappingProxyType from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union -from warnings import warn from .itransport import ITransport from .protocol import ParameterSchema @@ -27,6 +26,7 @@ identify_auth_requirements, params_to_pydantic_model, resolve_value, + warn_if_http_and_headers, ) @@ -272,6 +272,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: token_getter ) + warn_if_http_and_headers(self.__transport.base_url, headers) + return await self.__transport.tool_invoke( self.__name__, payload, diff --git a/packages/toolbox-core/src/toolbox_core/toolbox_transport.py b/packages/toolbox-core/src/toolbox_core/toolbox_transport.py index 0f1e7e40d..3f83d8306 100644 --- a/packages/toolbox-core/src/toolbox_core/toolbox_transport.py +++ b/packages/toolbox-core/src/toolbox_core/toolbox_transport.py @@ -13,12 +13,12 @@ # limitations under the License. from typing import Mapping, Optional -from warnings import warn from aiohttp import ClientSession from .itransport import ITransport from .protocol import ManifestSchema +from .utils import warn_if_http_and_headers class ToolboxTransport(ITransport): @@ -70,14 +70,6 @@ async def tools_list( async def tool_invoke( self, tool_name: str, arguments: dict, headers: Mapping[str, str] ) -> str: - # ID tokens contain sensitive user information (claims). Transmitting - # these over HTTP exposes the data to interception and unauthorized - # access. Always use HTTPS to ensure secure communication and protect - # user privacy. - if self.base_url.startswith("http://") and headers: - warn( - "Sending data token over HTTP. User data may be exposed. Use HTTPS for secure communication." - ) url = f"{self.__base_url}/api/tool/{tool_name}/invoke" async with self.__session.post( url, diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index 08a87a451..1f34c4bef 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -14,6 +14,7 @@ import asyncio +import warnings from typing import ( Any, Awaitable, @@ -43,6 +44,14 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) - return docstring +def warn_if_http_and_headers(url: str, headers: Mapping[str, Any] | None) -> None: + """Logs a warning if the url uses HTTP and sensitive headers are present.""" + if url.lower().startswith("http://") and headers: + warnings.warn( + "This connection is using HTTP. To prevent credential exposure, please ensure all communication is sent over HTTPS." + ) + + def identify_auth_requirements( req_authn_params: Mapping[str, list[str]], req_authz_tokens: Sequence[str], diff --git a/packages/toolbox-core/tests/test_tool.py b/packages/toolbox-core/tests/test_tool.py index 822cc6b9e..0f0286af8 100644 --- a/packages/toolbox-core/tests/test_tool.py +++ b/packages/toolbox-core/tests/test_tool.py @@ -685,3 +685,68 @@ async def test_bind_param_chaining( json={"count": 42, "message": "chained-call"}, headers={}, ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "base_url, headers, should_warn", + [ + ( + "http://fake-toolbox-server.com", + {"Authorization": "Bearer token"}, + True, + ), + ( + "https://fake-toolbox-server.com", + {"Authorization": "Bearer token"}, + False, + ), + ("http://fake-toolbox-server.com", {}, False), + ("http://fake-toolbox-server.com", None, False), + ], +) +async def test_tool_call_http_warning( + http_session: ClientSession, + base_url: str, + headers: Mapping[str, str] | None, + should_warn: bool, +): + """Tests the HTTP security warning logic during tool invocation via __call__.""" + url = f"{base_url}/api/tool/{TEST_TOOL_NAME}/invoke" + args = {"param1": "value1"} + response_payload = {"result": "success"} + transport = ToolboxTransport(base_url, http_session) + + tool = ToolboxTool( + transport=transport, + name=TEST_TOOL_NAME, + description="A tool", + params=[ + ParameterSchema(name="param1", type="string", description="param1 desc") + ], + required_authn_params={}, + required_authz_tokens=[], + auth_service_token_getters={}, + bound_params={}, + client_headers=headers if headers is not None else {}, + ) + + with aioresponses() as m: + m.post(url, status=200, payload=response_payload) + + if should_warn: + with pytest.warns( + UserWarning, + match="This connection is using HTTP. To prevent credential exposure, please ensure all communication is sent over HTTPS.", + ): + await tool(param1="value1") + else: + # Check no warnings fired + with catch_warnings(record=True) as record: + simplefilter("always") + await tool(param1="value1") + + warning_messages = [str(w.message) for w in record] + assert not any( + "This connection is using HTTP" in msg for msg in warning_messages + ) diff --git a/packages/toolbox-core/tests/test_toolbox_transport.py b/packages/toolbox-core/tests/test_toolbox_transport.py index 2921e09f8..84944274f 100644 --- a/packages/toolbox-core/tests/test_toolbox_transport.py +++ b/packages/toolbox-core/tests/test_toolbox_transport.py @@ -154,48 +154,6 @@ async def test_tool_invoke_failure(http_session: ClientSession): assert str(exc_info.value) == "Invalid arguments" -@pytest.mark.asyncio -@pytest.mark.parametrize( - "base_url, headers, should_warn", - [ - ( - "http://fake-toolbox-server.com", - {"Authorization": "Bearer token"}, - True, - ), - ( - "https://fake-toolbox-server.com", - {"Authorization": "Bearer token"}, - False, - ), - ("http://fake-toolbox-server.com", {}, False), - ("http://fake-toolbox-server.com", None, False), - ], -) -async def test_tool_invoke_http_warning( - http_session: ClientSession, - base_url: str, - headers: Optional[Mapping[str, str]], - should_warn: bool, -): - """Tests the HTTP security warning logic in tool_invoke.""" - url = f"{base_url}/api/tool/{TEST_TOOL_NAME}/invoke" - args = {"param1": "value1"} - response_payload = {"result": "success"} - transport = ToolboxTransport(base_url, http_session) - - with aioresponses() as m: - m.post(url, status=200, payload=response_payload) - - if should_warn: - with pytest.warns(UserWarning, match="Sending data token over HTTP"): - await transport.tool_invoke(TEST_TOOL_NAME, args, headers) - else: - # By not using pytest.warns, we assert that no warnings are raised. - # The test will fail if an unexpected UserWarning occurs. - await transport.tool_invoke(TEST_TOOL_NAME, args, headers) - - @pytest.mark.asyncio async def test_close_does_not_close_unmanaged_session(): """ diff --git a/packages/toolbox-core/tests/test_utils.py b/packages/toolbox-core/tests/test_utils.py index b3ddd7c33..1f54ae0dc 100644 --- a/packages/toolbox-core/tests/test_utils.py +++ b/packages/toolbox-core/tests/test_utils.py @@ -14,6 +14,7 @@ import asyncio +import warnings from typing import Type from unittest.mock import Mock @@ -26,6 +27,7 @@ identify_auth_requirements, params_to_pydantic_model, resolve_value, + warn_if_http_and_headers, ) @@ -458,3 +460,31 @@ async def another_async_func(): return {"key": "value"} assert await resolve_value(another_async_func) == {"key": "value"} + + +def test_warn_if_http_and_headers_triggers(): + """Test that a warning is emitted for HTTP URLs with headers.""" + url = "http://example.com" + headers = {"Authorization": "Bearer token"} + with pytest.warns(UserWarning, match="This connection is using HTTP"): + warn_if_http_and_headers(url, headers) + + +def test_warn_if_http_and_headers_no_headers(): + """Test that no warning is emitted for HTTP URLs without headers.""" + url = "http://example.com" + headers = {} + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + warn_if_http_and_headers(url, headers) + assert len(w) == 0 + + +def test_warn_if_http_and_headers_https(): + """Test that no warning is emitted for HTTPS URLs.""" + url = "https://example.com" + headers = {"Authorization": "Bearer token"} + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + warn_if_http_and_headers(url, headers) + assert len(w) == 0