Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions docs/source/reference/llms_envs.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
:orphan:

.. _llm_envs:

.. currentmodule:: torchrl.envs.llm

LLM Environments
Expand Down Expand Up @@ -159,3 +161,69 @@ trades rich display for portability.
ReplError
JupyterRepl
SubprocessRepl

Built-in tools and adapters
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchrl.envs.llm.agentic

.. autosummary::
:toctree: generated/
:template: rl_template.rst

ToolCompose
DispatchResult
PythonTool
ShellTool
FileReadTool
StopTool
HttpTool
MCPServerConfig
MCPToolset
RateLimiter
as_tool

Migration from legacy tool transforms
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Existing code built on :mod:`torchrl.envs.llm.transforms` keeps working:
no ``DeprecationWarning`` is emitted in this release. Each legacy class
has a ``.. seealso::`` block in its docstring pointing at the
recommended replacement, summarised here.

.. list-table:: Legacy transform → agentic counterpart
:header-rows: 1
:widths: 30 30 40

* - Legacy
- Agentic
- Adapter recipe
* - ``ExecuteToolsInOrder``
- :class:`ToolCompose`
- Replace at the env stack level. ``ToolCompose`` runs calls
concurrently; pin sequential execution per-tool with a
:class:`RateLimiter` configured with ``max_concurrent=1`` if you
depend on ordering.
* - ``PythonInterpreter``
- :class:`PythonTool` + :class:`Sandbox` + :class:`Repl`
- For a soft migration, lift the existing transform: ``as_tool(PythonInterpreter(persistent=True), name="python", input_schema=...)``.
* - ``SimpleToolTransform``
- Native :class:`Tool` subclass
- Or ``as_tool(transform, name=..., input_schema=...)``.
* - ``BrowserTransform``
- :func:`tools.as_tool` wrapping the existing transform
- A native :class:`Tool` for browser automation may land later;
until then the adapter is the recommended path.
* - ``MCPToolTransform``
- :class:`MCPToolset`
- One :class:`Tool` per remote tool, schemas auto-discovered.
Drops directly into ``ToolCompose``.
* - ``XMLBlockParser`` / ``JSONCallParser``
- :class:`parsers.XMLToolCallParser` / :class:`parsers.JSONToolCallParser`
- Same syntax; the agentic versions enforce a stable ``call_id``.
* - ``ToolService`` / ``ToolRegistry``
- The ``tools=[...]`` argument to :class:`ToolCompose`
- The registry pattern collapses into the compose container.

For a guided walkthrough, see the
:ref:`agentic ChatEnv tutorial <llm_agentic>`.
50 changes: 49 additions & 1 deletion test/llm/test_agentic.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
)
from torchrl.envs.llm.agentic.sandbox.subprocess_bwrap import _has_bwrap
from torchrl.envs.llm.agentic.sandbox.subprocess_seatbelt import _has_sandbox_exec
from torchrl.envs.llm.agentic.tools import as_tool
from torchrl.envs.llm.agentic.tools import as_tool, HttpTool
from torchrl.envs.llm.agentic.tools.mcp import _has_mcp, MCPServerConfig
from torchrl.envs.llm.transforms import IncrementalTokenizer


Expand Down Expand Up @@ -806,3 +807,50 @@ def _process_batch_item(self, content, index):
# The last appended message should be the tool result containing
# the legacy output.
assert "legacy got" in prompt[0][-1].content


# ----- MCP and HTTP tools -----


class TestMCPToolset:
    def test_construction_requires_mcp_package(self):
        if _has_mcp:
            # Package installed: constructing a toolset must not open a
            # session, and it starts with no discovered tools.
            from torchrl.envs.llm.agentic.tools import MCPToolset

            toolset = MCPToolset(MCPServerConfig(command="true"))
            assert toolset.tools == ()
        else:
            # Without the ``mcp`` package, construction must fail loudly.
            with pytest.raises(ImportError):
                from torchrl.envs.llm.agentic.tools import MCPToolset

                MCPToolset(MCPServerConfig(command="true"))

    def test_server_config(self):
        config = MCPServerConfig(command="npx", args=("@browsermcp/mcp@latest",))
        assert config.command == "npx"
        assert config.args == ("@browsermcp/mcp@latest",)


class TestHttpTool:
    def test_blocks_disallowed_host(self):
        async def scenario():
            tool = HttpTool(allowed_hosts=("api.example.com",))
            await tool.setup()
            result = await tool.run(
                {"url": "https://other-host.example/foo"},
                ToolContext(call_id="c"),
            )
            # A disallowed host yields an error result mentioning the
            # allow-list, not an exception.
            assert result.is_error
            assert "allowed_hosts" in result.text
            await tool.teardown()

        _run(scenario())

    def test_protocol_conformance(self):
        http_tool = HttpTool()
        # Sanity: it walks like a Tool.
        assert http_tool.name == "http"
        assert callable(http_tool.run)
        assert "url" in http_tool.input_schema["properties"]
6 changes: 6 additions & 0 deletions torchrl/envs/llm/agentic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from .schema import json_schema_from_pydantic, validate_args
from .tools import (
FileReadTool,
HttpTool,
MCPServerConfig,
MCPToolset,
PythonTool,
ShellTool,
StopSignal,
Expand All @@ -46,8 +49,11 @@
"DispatchResult",
"FileReadTool",
"FileRefPart",
"HttpTool",
"ImagePart",
"JsonPart",
"MCPServerConfig",
"MCPToolset",
"ParseResult",
"ParsedCall",
"PythonTool",
Expand Down
5 changes: 5 additions & 0 deletions torchrl/envs/llm/agentic/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
from __future__ import annotations

from .builtin import FileReadTool, PythonTool, ShellTool, StopTool, StopSignal
from .http import HttpTool
from .legacy_adapter import as_tool
from .mcp import MCPServerConfig, MCPToolset

__all__ = [
"FileReadTool",
"HttpTool",
"MCPServerConfig",
"MCPToolset",
"PythonTool",
"ShellTool",
"StopSignal",
Expand Down
154 changes: 154 additions & 0 deletions torchrl/envs/llm/agentic/tools/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""HTTP tool with built-in rate limiting.

Async-only via :mod:`urllib.request` offloaded to a worker thread (so we
don't introduce ``aiohttp``/``httpx`` as a hard dependency). For
production agentic workloads users should pair this with a
:class:`~torchrl.envs.llm.agentic.RateLimiter` keyed on the tool name in
the parent :class:`~torchrl.envs.llm.agentic.ToolCompose`.
"""
from __future__ import annotations

import asyncio
import json
from collections.abc import Mapping
from typing import Any, ClassVar
from urllib import error as urllib_error
from urllib import request as urllib_request

from ..protocols import TextPart, ToolContext, ToolError, ToolResult


_DEFAULT_MAX_BYTES = 1 << 20 # 1 MiB


class HttpTool:
    """Make an HTTP request and return the response body.

    Only ``http`` and ``https`` URLs are accepted; other schemes (e.g.
    ``file://``, ``ftp://``) are rejected with an error result so the
    tool cannot be abused to read local files through :mod:`urllib`.

    Args:
        allowed_hosts: If non-empty, requests to hosts not in this set
            return an error :class:`ToolResult`. Use
            ``("api.openai.com",)`` style; comparison is
            case-insensitive. Empty disables the check (use only with a
            stronger sandbox/network policy upstream).
        timeout: Per-request timeout (seconds).
        max_response_bytes: Cap on the returned body. Larger responses
            are truncated with a marker.

    Examples:
        >>> from torchrl.envs.llm.agentic.tools.http import HttpTool
        >>> tool = HttpTool(allowed_hosts=("api.example.com",))
    """

    name: ClassVar[str] = "http"
    description: ClassVar[str] = (
        "Make an HTTP request. Returns body, headers, status."
    )
    input_schema: ClassVar[Mapping[str, Any]] = {
        "type": "object",
        "properties": {
            "url": {"type": "string"},
            "method": {"type": "string"},  # default GET
            "headers": {"type": "object"},
            "body": {"type": "string"},
        },
        "required": ["url"],
    }
    output_schema: ClassVar[Mapping[str, Any] | None] = None
    wants_state: ClassVar[bool] = False

    def __init__(
        self,
        *,
        allowed_hosts: tuple[str, ...] = (),
        timeout: float = 10.0,
        max_response_bytes: int = _DEFAULT_MAX_BYTES,
    ) -> None:
        # Hostnames are case-insensitive (RFC 3986) and
        # ``urlparse(...).hostname`` is already lowercased, so normalise
        # the allow-list once here to make the membership test correct.
        self.allowed_hosts = tuple(h.lower() for h in allowed_hosts)
        self.timeout = timeout
        self.max_response_bytes = max_response_bytes

    async def setup(self) -> None:
        """No-op: the tool holds no session state."""

    async def teardown(self) -> None:
        """No-op: nothing to release."""

    async def run(
        self, args: Mapping[str, Any], ctx: ToolContext
    ) -> ToolResult:
        """Execute one request described by *args*.

        Blocked schemes/hosts and HTTP/URL failures are reported as an
        error :class:`ToolResult`; this method does not raise for them.
        """
        from urllib.parse import urlparse

        url = args["url"]
        method = (args.get("method") or "GET").upper()
        headers = dict(args.get("headers") or {})
        body = args.get("body")
        # Restrict to web schemes first: a ``file://`` URL has an empty
        # hostname and would otherwise slip past an empty allow-list.
        scheme = urlparse(url).scheme.lower()
        if scheme not in ("http", "https"):
            return ToolResult(
                parts=(TextPart(
                    text=f"scheme {scheme!r} not allowed; use http/https",
                ),),
                is_error=True,
                meta={"blocked_scheme": scheme},
            )
        if self.allowed_hosts:
            host = _host_of(url)
            if host not in self.allowed_hosts:
                return ToolResult(
                    parts=(TextPart(
                        text=(
                            f"host {host!r} not in allowed_hosts "
                            f"{self.allowed_hosts!r}"
                        ),
                    ),),
                    is_error=True,
                    meta={"blocked_host": host},
                )
        data = body.encode("utf-8") if isinstance(body, str) else body
        try:
            # Offload the blocking urllib call so we never stall the
            # event loop.
            status, resp_body, resp_headers = await asyncio.to_thread(
                _do_request, url, method, headers, data, self.timeout,
                self.max_response_bytes,
            )
        except urllib_error.HTTPError as e:
            return ToolResult(
                parts=(TextPart(text=f"HTTP {e.code}: {e.reason}"),),
                is_error=True,
                meta={"status": e.code},
            )
        except urllib_error.URLError as e:
            return ToolResult(
                parts=(TextPart(text=f"URL error: {e.reason}"),),
                is_error=True,
                meta={"error": str(e.reason)},
            )
        text = resp_body.decode("utf-8", errors="replace")
        # ``read(max_bytes)`` cannot distinguish an exact-size body from
        # a truncated one, so a body of exactly max_response_bytes is
        # conservatively flagged as truncated.
        truncated = len(resp_body) >= self.max_response_bytes
        if truncated:
            text += "\n... [truncated]"
        return ToolResult(
            parts=(TextPart(text=text),),
            is_error=status >= 400,
            meta={
                "status": status,
                "headers": dict(resp_headers),
                "truncated": truncated,
            },
        )


def _host_of(url: str) -> str:
from urllib.parse import urlparse

return urlparse(url).hostname or ""


def _do_request(
    url: str,
    method: str,
    headers: Mapping[str, str],
    data: bytes | None,
    timeout: float,
    max_bytes: int,
) -> tuple[int, bytes, Mapping[str, str]]:
    """Perform the blocking request; intended to run in a worker thread.

    Returns ``(status, body, headers)`` with the body capped at
    *max_bytes*.
    """
    request = urllib_request.Request(
        url,
        data=data,
        headers=dict(headers),
        method=method,
    )
    with urllib_request.urlopen(request, timeout=timeout) as response:
        payload = response.read(max_bytes)
        status = response.status
        header_map = dict(response.headers.items())
    return status, payload, header_map


__all__ = ["HttpTool"]
Loading
Loading