Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .protocol import Protocol, ToolSchema
from .tool import ToolboxTool
from .toolbox_transport import ToolboxTransport
from .utils import identify_auth_requirements, resolve_value
from .utils import identify_auth_requirements, resolve_value, warn_if_http_and_headers


class ToolboxClient:
Expand Down Expand Up @@ -101,6 +101,7 @@ def __init__(
raise ValueError(f"Unsupported MCP protocol version: {protocol}")

self.__client_headers = client_headers if client_headers is not None else {}
warn_if_http_and_headers(url, self.__client_headers)

def __parse_tool(
self,
Expand Down Expand Up @@ -224,6 +225,8 @@ async def load_tool(
for name, val in self.__client_headers.items()
}

warn_if_http_and_headers(self.__transport.base_url, auth_token_getters)

manifest = await self.__transport.tool_get(name, resolved_headers)

# parse the provided definition to a tool
Expand Down Expand Up @@ -299,6 +302,8 @@ async def load_toolset(
for header_name in original_headers
}

warn_if_http_and_headers(self.__transport.base_url, auth_token_getters)

manifest = await self.__transport.tools_list(name, resolved_headers)

tools: list[ToolboxTool] = []
Expand Down
4 changes: 3 additions & 1 deletion packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from inspect import Signature
from types import MappingProxyType
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
from warnings import warn

from .itransport import ITransport
from .protocol import ParameterSchema
Expand All @@ -27,6 +26,7 @@
identify_auth_requirements,
params_to_pydantic_model,
resolve_value,
warn_if_http_and_headers,
)


Expand Down Expand Up @@ -272,6 +272,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
token_getter
)

warn_if_http_and_headers(self.__transport.base_url, headers)

return await self.__transport.tool_invoke(
self.__name__,
payload,
Expand Down
10 changes: 1 addition & 9 deletions packages/toolbox-core/src/toolbox_core/toolbox_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

from typing import Mapping, Optional
from warnings import warn

from aiohttp import ClientSession

from .itransport import ITransport
from .protocol import ManifestSchema
from .utils import warn_if_http_and_headers


class ToolboxTransport(ITransport):
Expand Down Expand Up @@ -70,14 +70,6 @@ async def tools_list(
async def tool_invoke(
self, tool_name: str, arguments: dict, headers: Mapping[str, str]
) -> str:
# ID tokens contain sensitive user information (claims). Transmitting
# these over HTTP exposes the data to interception and unauthorized
# access. Always use HTTPS to ensure secure communication and protect
# user privacy.
if self.base_url.startswith("http://") and headers:
warn(
"Sending data token over HTTP. User data may be exposed. Use HTTPS for secure communication."
)
url = f"{self.__base_url}/api/tool/{tool_name}/invoke"
async with self.__session.post(
url,
Expand Down
9 changes: 9 additions & 0 deletions packages/toolbox-core/src/toolbox_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import asyncio
import warnings
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -43,6 +44,14 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) -
return docstring


def warn_if_http_and_headers(url: str, headers: Mapping[str, Any] | None) -> None:
"""Logs a warning if the url uses HTTP and sensitive headers are present."""
if url.lower().startswith("http://") and headers:
warnings.warn(
"This connection is using HTTP. To prevent credential exposure, please ensure all communication is sent over HTTPS."
)


def identify_auth_requirements(
req_authn_params: Mapping[str, list[str]],
req_authz_tokens: Sequence[str],
Expand Down
65 changes: 65 additions & 0 deletions packages/toolbox-core/tests/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,3 +685,68 @@ async def test_bind_param_chaining(
json={"count": 42, "message": "chained-call"},
headers={},
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"base_url, headers, should_warn",
[
(
"http://fake-toolbox-server.com",
{"Authorization": "Bearer token"},
True,
),
(
"https://fake-toolbox-server.com",
{"Authorization": "Bearer token"},
False,
),
("http://fake-toolbox-server.com", {}, False),
("http://fake-toolbox-server.com", None, False),
],
)
async def test_tool_call_http_warning(
http_session: ClientSession,
base_url: str,
headers: Mapping[str, str] | None,
should_warn: bool,
):
"""Tests the HTTP security warning logic during tool invocation via __call__."""
url = f"{base_url}/api/tool/{TEST_TOOL_NAME}/invoke"
args = {"param1": "value1"}
response_payload = {"result": "success"}
transport = ToolboxTransport(base_url, http_session)

tool = ToolboxTool(
transport=transport,
name=TEST_TOOL_NAME,
description="A tool",
params=[
ParameterSchema(name="param1", type="string", description="param1 desc")
],
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers=headers if headers is not None else {},
)

with aioresponses() as m:
m.post(url, status=200, payload=response_payload)

if should_warn:
with pytest.warns(
UserWarning,
match="This connection is using HTTP. To prevent credential exposure, please ensure all communication is sent over HTTPS.",
):
await tool(param1="value1")
else:
# Check no warnings fired
with catch_warnings(record=True) as record:
simplefilter("always")
await tool(param1="value1")

warning_messages = [str(w.message) for w in record]
assert not any(
"This connection is using HTTP" in msg for msg in warning_messages
)
42 changes: 0 additions & 42 deletions packages/toolbox-core/tests/test_toolbox_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,48 +154,6 @@ async def test_tool_invoke_failure(http_session: ClientSession):
assert str(exc_info.value) == "Invalid arguments"


@pytest.mark.asyncio
@pytest.mark.parametrize(
"base_url, headers, should_warn",
[
(
"http://fake-toolbox-server.com",
{"Authorization": "Bearer token"},
True,
),
(
"https://fake-toolbox-server.com",
{"Authorization": "Bearer token"},
False,
),
("http://fake-toolbox-server.com", {}, False),
("http://fake-toolbox-server.com", None, False),
],
)
async def test_tool_invoke_http_warning(
http_session: ClientSession,
base_url: str,
headers: Optional[Mapping[str, str]],
should_warn: bool,
):
"""Tests the HTTP security warning logic in tool_invoke."""
url = f"{base_url}/api/tool/{TEST_TOOL_NAME}/invoke"
args = {"param1": "value1"}
response_payload = {"result": "success"}
transport = ToolboxTransport(base_url, http_session)

with aioresponses() as m:
m.post(url, status=200, payload=response_payload)

if should_warn:
with pytest.warns(UserWarning, match="Sending data token over HTTP"):
await transport.tool_invoke(TEST_TOOL_NAME, args, headers)
else:
# By not using pytest.warns, we assert that no warnings are raised.
# The test will fail if an unexpected UserWarning occurs.
await transport.tool_invoke(TEST_TOOL_NAME, args, headers)


@pytest.mark.asyncio
async def test_close_does_not_close_unmanaged_session():
"""
Expand Down
30 changes: 30 additions & 0 deletions packages/toolbox-core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import asyncio
import warnings
from typing import Type
from unittest.mock import Mock

Expand All @@ -26,6 +27,7 @@
identify_auth_requirements,
params_to_pydantic_model,
resolve_value,
warn_if_http_and_headers,
)


Expand Down Expand Up @@ -458,3 +460,31 @@ async def another_async_func():
return {"key": "value"}

assert await resolve_value(another_async_func) == {"key": "value"}


def test_warn_if_http_and_headers_triggers():
"""Test that a warning is emitted for HTTP URLs with headers."""
url = "http://example.com"
headers = {"Authorization": "Bearer token"}
with pytest.warns(UserWarning, match="This connection is using HTTP"):
warn_if_http_and_headers(url, headers)


def test_warn_if_http_and_headers_no_headers():
"""Test that no warning is emitted for HTTP URLs without headers."""
url = "http://example.com"
headers = {}
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
warn_if_http_and_headers(url, headers)
assert len(w) == 0


def test_warn_if_http_and_headers_https():
"""Test that no warning is emitted for HTTPS URLs."""
url = "https://example.com"
headers = {"Authorization": "Bearer token"}
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
warn_if_http_and_headers(url, headers)
assert len(w) == 0
Loading