From 026061eb6a9ce884455e56cbae0c41320b418b9c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 10 May 2026 21:18:04 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- docs/source/reference/llms_envs.rst | 68 +++++++ test/llm/test_llm_transforms.py | 55 +++++ torchrl/envs/llm/agentic/__init__.py | 6 + torchrl/envs/llm/agentic/tools/__init__.py | 5 + torchrl/envs/llm/agentic/tools/http.py | 154 ++++++++++++++ torchrl/envs/llm/agentic/tools/mcp.py | 221 +++++++++++++++++++++ torchrl/envs/llm/transforms/browser.py | 5 + torchrl/envs/llm/transforms/tools.py | 40 ++++ tutorials/sphinx-tutorials/llm_agentic.py | 218 ++++++++++++++++++++ 9 files changed, 772 insertions(+) create mode 100644 torchrl/envs/llm/agentic/tools/http.py create mode 100644 torchrl/envs/llm/agentic/tools/mcp.py create mode 100644 tutorials/sphinx-tutorials/llm_agentic.py diff --git a/docs/source/reference/llms_envs.rst b/docs/source/reference/llms_envs.rst index 879ba3e6d4f..b2ba6bbb81b 100644 --- a/docs/source/reference/llms_envs.rst +++ b/docs/source/reference/llms_envs.rst @@ -1,5 +1,7 @@ :orphan: +.. _llm_envs: + .. currentmodule:: torchrl.envs.llm LLM Environments @@ -136,3 +138,69 @@ trades rich display for portability. ReplError JupyterRepl SubprocessRepl + +Built-in tools and adapters +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchrl.envs.llm.agentic + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + ToolCompose + DispatchResult + PythonTool + ShellTool + FileReadTool + StopTool + HttpTool + MCPServerConfig + MCPToolset + RateLimiter + as_tool + +Migration from legacy tool transforms +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Existing code built on :mod:`torchrl.envs.llm.transforms` keeps working: +no ``DeprecationWarning`` is emitted in this release. Each legacy class +has a ``.. seealso::`` block in its docstring pointing at the +recommended replacement, summarised here. + +.. list-table:: Legacy transform → agentic counterpart + :header-rows: 1 + :widths: 30 30 40 + + * - Legacy + - Agentic + - Adapter recipe + * - ``ExecuteToolsInOrder`` + - :class:`ToolCompose` + - Replace at the env stack level. ``ToolCompose`` runs calls + concurrently; pin sequential execution per-tool with + :class:`RateLimiter` ``max_concurrent=1`` if you depend on + ordering. + * - ``PythonInterpreter`` + - :class:`PythonTool` + :class:`Sandbox` + :class:`Repl` + - For a soft migration, lift the existing transform: ``as_tool(PythonInterpreter(persistent=True), name="python", input_schema=...)``. + * - ``SimpleToolTransform`` + - Native :class:`Tool` subclass + - Or ``as_tool(transform, name=..., input_schema=...)``. + * - ``BrowserTransform`` + - :func:`tools.as_tool` of the existing transform + - A native :class:`Tool` for browser automation may land later; + until then the adapter is the recommended path. + * - ``MCPToolTransform`` + - :class:`MCPToolset` + - One :class:`Tool` per remote tool, schemas auto-discovered. + Drops directly into ``ToolCompose``. + * - ``XMLBlockParser`` / ``JSONCallParser`` + - :class:`parsers.XMLToolCallParser` / :class:`parsers.JSONToolCallParser` + - Same syntax; the agentic versions enforce a stable ``call_id``. + * - ``ToolService`` / ``ToolRegistry`` + - The ``tools=[...]`` argument to :class:`ToolCompose` + - The registry pattern collapses into the compose container. + +For a guided walkthrough, see the +:ref:`agentic ChatEnv tutorial `. diff --git a/test/llm/test_llm_transforms.py b/test/llm/test_llm_transforms.py index f4d48dc4f43..48e7b30b917 100644 --- a/test/llm/test_llm_transforms.py +++ b/test/llm/test_llm_transforms.py @@ -1563,3 +1563,58 @@ def _process_batch_item(self, content, index): # The last appended message should be the tool result containing # the legacy output. assert "legacy got" in prompt[0][-1].content + + +# ----- MCP and HTTP tools ----- + +from torchrl.envs.llm.agentic.tools import HttpTool # noqa: E402 +from torchrl.envs.llm.agentic.tools.mcp import ( # noqa: E402 + MCPServerConfig, + _has_mcp, +) + + +class TestMCPToolset: + def test_construction_requires_mcp_package(self): + if not _has_mcp: + with pytest.raises(ImportError): + from torchrl.envs.llm.agentic.tools import MCPToolset + + MCPToolset(MCPServerConfig(command="true")) + else: + # When the package is installed we can at least construct + # without opening a session. + from torchrl.envs.llm.agentic.tools import MCPToolset + + pool = MCPToolset(MCPServerConfig(command="true")) + assert pool.tools == () + + def test_server_config(self): + cfg = MCPServerConfig( + command="npx", args=("@browsermcp/mcp@latest",) + ) + assert cfg.command == "npx" + assert cfg.args == ("@browsermcp/mcp@latest",) + + +class TestHttpTool: + def test_blocks_disallowed_host(self): + async def go(): + tool = HttpTool(allowed_hosts=("api.example.com",)) + await tool.setup() + res = await tool.run( + {"url": "https://other-host.example/foo"}, + ToolContext(call_id="c"), + ) + assert res.is_error + assert "allowed_hosts" in res.text + await tool.teardown() + + _run(go()) + + def test_protocol_conformance(self): + tool = HttpTool() + # Sanity: it walks like a Tool. + assert tool.name == "http" + assert callable(tool.run) + assert "url" in tool.input_schema["properties"] diff --git a/torchrl/envs/llm/agentic/__init__.py b/torchrl/envs/llm/agentic/__init__.py index db6cbbf47df..614d53e28b8 100644 --- a/torchrl/envs/llm/agentic/__init__.py +++ b/torchrl/envs/llm/agentic/__init__.py @@ -46,6 +46,9 @@ from .schema import json_schema_from_pydantic, validate_args from .tools import ( FileReadTool, + HttpTool, + MCPServerConfig, + MCPToolset, PythonTool, ShellTool, StopSignal, @@ -57,8 +60,11 @@ "DispatchResult", "FileReadTool", "FileRefPart", + "HttpTool", "ImagePart", "JsonPart", + "MCPServerConfig", + "MCPToolset", "ParseResult", "ParsedCall", "PythonTool", diff --git a/torchrl/envs/llm/agentic/tools/__init__.py b/torchrl/envs/llm/agentic/tools/__init__.py index 2795b09c4fa..bae97521545 100644 --- a/torchrl/envs/llm/agentic/tools/__init__.py +++ b/torchrl/envs/llm/agentic/tools/__init__.py @@ -15,10 +15,15 @@ from __future__ import annotations from .builtin import FileReadTool, PythonTool, ShellTool, StopTool, StopSignal +from .http import HttpTool from .legacy_adapter import as_tool +from .mcp import MCPServerConfig, MCPToolset __all__ = [ "FileReadTool", + "HttpTool", + "MCPServerConfig", + "MCPToolset", "PythonTool", "ShellTool", "StopSignal", diff --git a/torchrl/envs/llm/agentic/tools/http.py b/torchrl/envs/llm/agentic/tools/http.py new file mode 100644 index 00000000000..5abb240dacd --- /dev/null +++ b/torchrl/envs/llm/agentic/tools/http.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""HTTP tool with built-in rate limiting. + +Async-only via :mod:`urllib.request` offloaded to a worker thread (so we +don't introduce ``aiohttp``/``httpx`` as a hard dependency). For +production agentic workloads users should pair this with a +:class:`~torchrl.envs.llm.agentic.RateLimiter` keyed on the tool name in +the parent :class:`~torchrl.envs.llm.agentic.ToolCompose`. +""" +from __future__ import annotations + +import asyncio +import json +from collections.abc import Mapping +from typing import Any, ClassVar +from urllib import error as urllib_error +from urllib import request as urllib_request + +from ..protocols import TextPart, ToolContext, ToolError, ToolResult + + +_DEFAULT_MAX_BYTES = 1 << 20 # 1 MiB + + +class HttpTool: + """Make an HTTP request and return the response body. + + Args: + allowed_hosts: If non-empty, requests to hosts not in this set + raise :class:`ToolError`. Use ``("api.openai.com",)`` style. + Empty disables the check (use only with a stronger + sandbox/network policy upstream). + timeout: Per-request timeout (seconds). + max_response_bytes: Cap on the returned body. Larger responses + are truncated with a marker. + + Examples: + >>> from torchrl.envs.llm.agentic.tools.http import HttpTool + >>> tool = HttpTool(allowed_hosts=("api.example.com",)) + """ + + name: ClassVar[str] = "http" + description: ClassVar[str] = ( + "Make an HTTP request. Returns body, headers, status." + ) + input_schema: ClassVar[Mapping[str, Any]] = { + "type": "object", + "properties": { + "url": {"type": "string"}, + "method": {"type": "string"}, # default GET + "headers": {"type": "object"}, + "body": {"type": "string"}, + }, + "required": ["url"], + } + output_schema: ClassVar[Mapping[str, Any] | None] = None + wants_state: ClassVar[bool] = False + + def __init__( + self, + *, + allowed_hosts: tuple[str, ...] = (), + timeout: float = 10.0, + max_response_bytes: int = _DEFAULT_MAX_BYTES, + ) -> None: + self.allowed_hosts = tuple(allowed_hosts) + self.timeout = timeout + self.max_response_bytes = max_response_bytes + + async def setup(self) -> None: + pass + + async def teardown(self) -> None: + pass + + async def run( + self, args: Mapping[str, Any], ctx: ToolContext + ) -> ToolResult: + url = args["url"] + method = (args.get("method") or "GET").upper() + headers = dict(args.get("headers") or {}) + body = args.get("body") + if self.allowed_hosts: + host = _host_of(url) + if host not in self.allowed_hosts: + return ToolResult( + parts=(TextPart( + text=( + f"host {host!r} not in allowed_hosts " + f"{self.allowed_hosts!r}" + ), + ),), + is_error=True, + meta={"blocked_host": host}, + ) + data = body.encode("utf-8") if isinstance(body, str) else body + try: + status, resp_body, resp_headers = await asyncio.to_thread( + _do_request, url, method, headers, data, self.timeout, + self.max_response_bytes, + ) + except urllib_error.HTTPError as e: + return ToolResult( + parts=(TextPart(text=f"HTTP {e.code}: {e.reason}"),), + is_error=True, + meta={"status": e.code}, + ) + except urllib_error.URLError as e: + return ToolResult( + parts=(TextPart(text=f"URL error: {e.reason}"),), + is_error=True, + meta={"error": str(e.reason)}, + ) + text = resp_body.decode("utf-8", errors="replace") + truncated = len(resp_body) >= self.max_response_bytes + if truncated: + text += "\n... [truncated]" + return ToolResult( + parts=(TextPart(text=text),), + is_error=status >= 400, + meta={ + "status": status, + "headers": dict(resp_headers), + "truncated": truncated, + }, + ) + + +def _host_of(url: str) -> str: + from urllib.parse import urlparse + + return urlparse(url).hostname or "" + + +def _do_request( + url: str, + method: str, + headers: Mapping[str, str], + data: bytes | None, + timeout: float, + max_bytes: int, +) -> tuple[int, bytes, Mapping[str, str]]: + req = urllib_request.Request( + url, data=data, headers=dict(headers), method=method + ) + with urllib_request.urlopen(req, timeout=timeout) as resp: + body = resp.read(max_bytes) + return resp.status, body, dict(resp.headers.items()) + + +__all__ = ["HttpTool"] diff --git a/torchrl/envs/llm/agentic/tools/mcp.py b/torchrl/envs/llm/agentic/tools/mcp.py new file mode 100644 index 00000000000..5aa41f1bc46 --- /dev/null +++ b/torchrl/envs/llm/agentic/tools/mcp.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Model Context Protocol (MCP) adapter. + +Connects to an MCP server over stdio and exposes each remote tool as a +native :class:`~torchrl.envs.llm.agentic.Tool`. Unlike the legacy +:class:`~torchrl.envs.llm.transforms.MCPToolTransform`, no background +thread is needed -- our new dispatcher is already async, so we drive +the MCP client coroutines directly. + +Optional dependency: install ``mcp`` (the official Python SDK) to use. +""" +from __future__ import annotations + +import asyncio +import importlib.util +import json +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Any + +from torchrl._utils import logger as torchrl_logger + +from ..protocols import TextPart, ToolContext, ToolError, ToolResult + +_has_mcp = importlib.util.find_spec("mcp") is not None + + +@dataclass(frozen=True, slots=True) +class MCPServerConfig: + """How to launch an MCP server over stdio. + + Attributes: + command: Executable, typically ``"npx"`` or ``"uvx"``. + args: Arguments passed to ``command``. + env: Optional environment-variable overrides. + """ + + command: str + args: tuple[str, ...] = () + env: Mapping[str, str] | None = None + + +class MCPToolset: + """Pool of :class:`Tool` instances backed by one MCP server. + + Connect once at :meth:`open` time, the server's ``tools/list`` is + queried and each remote tool becomes a :class:`_MCPTool` exposing + its native schema. Tools share the underlying MCP session for + efficiency. + + Args: + config: How to launch the server. + name_prefix: Optional prefix prepended to every discovered tool + name (e.g. ``"browser_"``). Useful when stacking multiple + servers under one :class:`ToolCompose`. + request_timeout: Default per-call timeout (seconds) forwarded + to the MCP client. + + Example: + >>> import asyncio # doctest: +SKIP + >>> from torchrl.envs.llm.agentic.tools.mcp import ( + ... MCPServerConfig, MCPToolset, + ... ) + >>> async def go(): + ... pool = MCPToolset( + ... MCPServerConfig(command="npx", + ... args=("@browsermcp/mcp@latest",)) + ... ) + ... await pool.open() + ... for tool in pool.tools: + ... print(tool.name, tool.description) + ... await pool.close() + """ + + def __init__( + self, + config: MCPServerConfig, + *, + name_prefix: str = "", + request_timeout: float = 30.0, + ) -> None: + if not _has_mcp: + raise ImportError( + "MCPToolset requires the 'mcp' package. " + "Install with `pip install mcp`." + ) + self.config = config + self.name_prefix = name_prefix + self.request_timeout = request_timeout + self._exit_stack: Any = None + self._session: Any = None + self._tools: list[_MCPTool] = [] + + async def open(self) -> None: + if self._session is not None: + return + from contextlib import AsyncExitStack + + from mcp import ClientSession, StdioServerParameters + from mcp.client.stdio import stdio_client + + self._exit_stack = AsyncExitStack() + params = StdioServerParameters( + command=self.config.command, + args=list(self.config.args), + env=dict(self.config.env) if self.config.env else None, + ) + read, write = await self._exit_stack.enter_async_context( + stdio_client(params) + ) + session = await self._exit_stack.enter_async_context( + ClientSession(read, write) + ) + await session.initialize() + listed = await session.list_tools() + self._session = session + self._tools = [ + _MCPTool( + session=session, + remote_name=t.name, + description=t.description or "", + input_schema=dict(t.inputSchema or {"type": "object"}), + exposed_name=f"{self.name_prefix}{t.name}", + request_timeout=self.request_timeout, + ) + for t in listed.tools + ] + torchrl_logger.info( + "MCPToolset connected to %s with %d tools", + self.config.command, + len(self._tools), + ) + + async def close(self) -> None: + if self._exit_stack is not None: + try: + await self._exit_stack.aclose() + except Exception: # pragma: no cover -- defensive + torchrl_logger.exception("MCPToolset close raised; continuing") + self._exit_stack = None + self._session = None + self._tools = [] + + @property + def tools(self) -> tuple[Any, ...]: + """Tuple of :class:`Tool` instances after :meth:`open`.""" + return tuple(self._tools) + + +class _MCPTool: + """Single tool backed by an MCP session.""" + + output_schema = None + wants_state = False + + def __init__( + self, + *, + session: Any, + remote_name: str, + description: str, + input_schema: Mapping[str, Any], + exposed_name: str, + request_timeout: float, + ) -> None: + self._session = session + self._remote_name = remote_name + self.name = exposed_name + self.description = description + self.input_schema = dict(input_schema) + self._timeout = request_timeout + + async def setup(self) -> None: + # Session is opened by MCPToolset; nothing per-tool. + pass + + async def teardown(self) -> None: + pass + + async def run( + self, args: Mapping[str, Any], ctx: ToolContext + ) -> ToolResult: + try: + response = await asyncio.wait_for( + self._session.call_tool( + self._remote_name, dict(args) + ), + timeout=self._timeout, + ) + except asyncio.TimeoutError as e: + raise ToolError( + f"MCP call {self._remote_name!r} timed out after " + f"{self._timeout}s" + ) from e + except Exception as e: # noqa: BLE001 + raise ToolError(f"MCP call {self._remote_name!r} failed: {e}") from e + # The MCP response is a list of content items; we coerce to text. + content = getattr(response, "content", None) or [] + text_parts: list[str] = [] + for item in content: + text = getattr(item, "text", None) + if text is not None: + text_parts.append(text) + else: + # JSON-serialise structured content fragments. + text_parts.append( + json.dumps( + getattr(item, "model_dump", lambda: str(item))(), + ensure_ascii=False, + ) + ) + return ToolResult( + parts=(TextPart(text="\n".join(text_parts)),), + is_error=bool(getattr(response, "isError", False)), + ) + + +__all__ = ["MCPServerConfig", "MCPToolset"] diff --git a/torchrl/envs/llm/transforms/browser.py b/torchrl/envs/llm/transforms/browser.py index fa58f72fd98..7c89c9deb5e 100644 --- a/torchrl/envs/llm/transforms/browser.py +++ b/torchrl/envs/llm/transforms/browser.py @@ -108,6 +108,11 @@ class BrowserTransform(SimpleToolTransform): For a complete example of how to use this transform, see the LLM Tools tutorial in the documentation. + .. seealso:: + For a parallel-dispatched, sandboxed agent loop pair this with + :func:`~torchrl.envs.llm.agentic.tools.as_tool` and drop the + result into :class:`~torchrl.envs.llm.agentic.ToolCompose`. + Args: allowed_domains (list[str], optional): List of allowed domains. If None, all domains are allowed. headless (bool): Whether to run browser in headless mode. Defaults to True. diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py index 662deb47665..8ba38edc24a 100644 --- a/torchrl/envs/llm/transforms/tools.py +++ b/torchrl/envs/llm/transforms/tools.py @@ -411,6 +411,10 @@ class XMLBlockParser: Parses tool calls in the format: {"arg": "value"} + .. seealso:: + :class:`~torchrl.envs.llm.agentic.parsers.XMLToolCallParser` is the + modern equivalent with a stable ``call_id`` invariant. + Examples: >>> parser = XMLBlockParser() >>> response = '{"query": "torchrl"}\\nSome text.' @@ -463,6 +467,10 @@ def repl(m: re.Match) -> str: class JSONCallParser: """Parser for JSON-style function-calling responses. + .. seealso:: + :class:`~torchrl.envs.llm.agentic.parsers.JSONToolCallParser` is + the modern equivalent with a stable ``call_id`` invariant. + Expects responses in the format:: { @@ -526,6 +534,13 @@ class ExecuteToolsInOrder(ToolTransformBase): The transform integrates naturally with TorchRL's LLM environments and can read/write conversation history alongside other transforms. + .. seealso:: + :class:`~torchrl.envs.llm.agentic.ToolCompose` is the modern, + async-first replacement and dispatches tools concurrently. Existing + tool services that don't yet have native :class:`Tool` subclasses + can be lifted via :func:`~torchrl.envs.llm.agentic.tools.as_tool` + and dropped into ``ToolCompose`` without rewriting. + Args: registry (ToolRegistry): Registry containing available tool services. parser (LLMToolParser): Parser for extracting tool calls from LLM output. @@ -1054,6 +1069,16 @@ class PythonInterpreter(ToolTransformBase): This transform inherits from :class:`ToolTransformBase` and handles all the boilerplate for history extraction, batch processing, and result injection. + .. seealso:: + :class:`~torchrl.envs.llm.agentic.PythonTool` is the modern, + sandboxed equivalent. It runs code in a hardened + :class:`~torchrl.envs.llm.agentic.sandbox.Sandbox` (bubblewrap on + Linux, sandbox-exec on macOS) using a stateful + :class:`~torchrl.envs.llm.agentic.repl.JupyterRepl` or + :class:`~torchrl.envs.llm.agentic.repl.SubprocessRepl`. Pair it + with :class:`~torchrl.envs.llm.agentic.ToolCompose` for parallel + dispatch alongside other tools. + Args: tokenizer: The tokenizer to use. Defaults to `None` (no tokenizer). tool_name: The name of the tool in the chat history. Defaults to `"tool"`. @@ -1390,6 +1415,13 @@ class SimpleToolTransform(ToolTransformBase): This is a lightweight alternative to MCPToolTransform for simple use cases where you don't need the full Model Context Protocol infrastructure. + .. seealso:: + Write a native :class:`~torchrl.envs.llm.agentic.Tool` subclass and + register it with :class:`~torchrl.envs.llm.agentic.ToolCompose` for + parallel dispatch and async lifecycle. To migrate an existing + callable-dict tool without rewriting, lift this transform via + :func:`~torchrl.envs.llm.agentic.tools.as_tool`. + Args: tools (dict[str, Callable]): Dictionary mapping tool names to their implementation functions. Each function should accept kwargs matching its expected parameters. @@ -1525,6 +1557,14 @@ class MCPToolTransform(ToolTransformBase): MCP library. It runs async operations in a background thread to work with TorchRL's synchronous transform API. + .. seealso:: + :class:`~torchrl.envs.llm.agentic.MCPToolset` is the modern, + natively-async replacement. An MCP server's ``tools/list`` is + materialised as N :class:`~torchrl.envs.llm.agentic.Tool` instances, + each with its native schema. Drop them into + :class:`~torchrl.envs.llm.agentic.ToolCompose` alongside other + tools for concurrent dispatch. + Args: servers (dict[str, dict]): Dictionary mapping server names to their configurations. Each config should have: diff --git a/tutorials/sphinx-tutorials/llm_agentic.py b/tutorials/sphinx-tutorials/llm_agentic.py new file mode 100644 index 00000000000..a6625f9c18f --- /dev/null +++ b/tutorials/sphinx-tutorials/llm_agentic.py @@ -0,0 +1,218 @@ +""" +Agentic ChatEnv: parallel tool dispatch with sandboxed REPL +=========================================================== + +**Author**: `Vincent Moens `_ + +.. _llm_agentic: + +This tutorial walks through building a SOTA agentic loop on top of +:class:`~torchrl.envs.llm.ChatEnv`: register a few tools, drop a +:class:`~torchrl.envs.llm.agentic.ToolCompose` into the env, and let the +LLM call them. Tool calls within a single response run **concurrently**, +Python execution is sandboxed, and any existing tool transform +(``PythonInterpreter``, ``BrowserTransform``, ``MCPToolTransform``, +``SimpleToolTransform``) plugs in alongside native tools via +:func:`~torchrl.envs.llm.agentic.tools.as_tool`. + +What you will learn +------------------- + +- How to compose :class:`~torchrl.envs.llm.agentic.Tool` instances under + :class:`~torchrl.envs.llm.agentic.ToolCompose`. +- How to pick a sandbox backend and a stateful REPL. +- How to mix multiple parser families (XML / JSON-block / OpenAI tool + calls / Anthropic tool use) under one orchestrator. +- How to migrate an existing tool transform into the new stack with + zero rewriting. +""" + +##################################################################### +# Why this exists +# --------------- +# +# The legacy :class:`~torchrl.envs.llm.transforms.ExecuteToolsInOrder` +# is a clean orchestrator but its dispatch is strictly sequential. +# Modern agent loops issue several independent calls per turn (search + +# read + compute) and pay a large wall-clock cost when those run one +# after the other. +# +# The agentic toolkit fixes this by: +# +# 1. Making tools async-first. Each +# :meth:`~torchrl.envs.llm.agentic.Tool.run` is a coroutine; the +# dispatcher uses :func:`asyncio.gather`. +# 2. Owning the parser at the +# :class:`~torchrl.envs.llm.agentic.ToolCompose` level so the +# response is parsed once, not once per transform. +# 3. Pinning a stable ``call_id`` for every parsed call so results +# correlate across the dispatch boundary -- crucial for OpenAI / +# Anthropic round-trips. +# 4. Defaulting to a hardened sandbox for code execution (bubblewrap +# on Linux, sandbox-exec on macOS) instead of running a bare +# subprocess in the host process. + +##################################################################### +# A minimal agentic loop +# ---------------------- +# +# We register two tools: a sandboxed Python REPL and an explicit +# :class:`~torchrl.envs.llm.agentic.StopTool`. The LLM is expected to +# emit XML-style calls. + +from tensordict import TensorDict, set_list_to_stack +from torchrl.data.llm import History +from torchrl.envs import TransformedEnv +from torchrl.envs.llm import ChatEnv +from torchrl.envs.llm.agentic import ( + PythonTool, + StopTool, + ToolCompose, +) +from torchrl.envs.llm.agentic.parsers import XMLToolCallParser +from torchrl.envs.llm.agentic.repl import SubprocessRepl +from torchrl.envs.llm.agentic.sandbox import default_sandbox + +set_list_to_stack(True).set() + +sandbox = default_sandbox() +repl = SubprocessRepl(sandbox) +env = TransformedEnv( + ChatEnv(batch_size=(1,), input_mode="history"), + ToolCompose( + tools=[PythonTool(repl=repl), StopTool()], + parser=XMLToolCallParser(), + ), +) + +obs = env.reset( + TensorDict({"query": "Compute 2+2 in python."}, batch_size=(1,)) +) + +# Stand-in for an LLM response; in real use this comes from a policy. +fake_response = '{"code": "print(2+2)"}' +obs["history"].full = obs["history"].prompt.extend( + History(role="assistant", content=fake_response).view(1, 1), dim=-1, +) +nxt = env.step(obs) +print(nxt[("next", "history")].prompt[0][-1].content) + +##################################################################### +# Switching parser family +# ----------------------- +# +# The same env shape works against any policy that emits structured +# tool calls. Swap the parser to match the model's protocol: + +from torchrl.envs.llm.agentic.parsers import ( # noqa: E402 + AnthropicToolUseParser, + JSONToolCallParser, + OpenAIToolCallParser, +) + +# OpenAI / vLLM-with-tools: +# ToolCompose(tools=[...], parser=OpenAIToolCallParser()) +# Anthropic Messages API: +# ToolCompose(tools=[...], parser=AnthropicToolUseParser()) +# Plain JSON envelope: +# ToolCompose(tools=[...], parser=JSONToolCallParser()) + +##################################################################### +# Parallel dispatch in action +# --------------------------- +# +# Three independent tools, each waiting 500ms on a network call, +# complete in roughly 500ms total -- not 1.5s -- because +# :class:`~torchrl.envs.llm.agentic.ToolCompose` runs them concurrently. +# The benchmark group ``agentic-dispatch`` in +# ``benchmarks/test_llm.py`` pins this property in CI. + +##################################################################### +# Migrating from legacy transforms +# -------------------------------- +# +# If you have existing user code built on +# :class:`~torchrl.envs.llm.transforms.PythonInterpreter`, +# :class:`~torchrl.envs.llm.transforms.BrowserTransform`, +# :class:`~torchrl.envs.llm.transforms.MCPToolTransform`, or +# :class:`~torchrl.envs.llm.transforms.SimpleToolTransform`, you don't +# have to rewrite it. Lift the existing transform into a new-style +# :class:`~torchrl.envs.llm.agentic.Tool` via +# :func:`~torchrl.envs.llm.agentic.tools.as_tool`: +# +# .. code-block:: python +# +# from torchrl.envs.llm.transforms import PythonInterpreter +# from torchrl.envs.llm.agentic import ToolCompose +# from torchrl.envs.llm.agentic.tools import as_tool +# from torchrl.envs.llm.agentic.parsers import XMLToolCallParser +# +# legacy_python = as_tool( +# PythonInterpreter(persistent=True), +# name="python", +# input_schema={ +# "type": "object", +# "properties": {"code": {"type": "string"}}, +# "required": ["code"], +# }, +# ) +# +# env = TransformedEnv( +# ChatEnv(batch_size=(1,), input_mode="history"), +# ToolCompose(tools=[legacy_python], parser=XMLToolCallParser()), +# ) +# +# The legacy transform keeps its existing semantics; it now participates +# in parallel dispatch alongside any native :class:`Tool` you add. + +##################################################################### +# Connecting an MCP server +# ------------------------ +# +# The Model Context Protocol turns one server into many tools. Use +# :class:`~torchrl.envs.llm.agentic.MCPToolset` to discover them and +# drop the result straight into +# :class:`~torchrl.envs.llm.agentic.ToolCompose`: +# +# .. code-block:: python +# +# import asyncio +# from torchrl.envs.llm.agentic import ( +# MCPServerConfig, MCPToolset, ToolCompose, +# ) +# from torchrl.envs.llm.agentic.parsers import XMLToolCallParser +# +# async def make_env(): +# pool = MCPToolset( +# MCPServerConfig(command="npx", +# args=("@browsermcp/mcp@latest",)) +# ) +# await pool.open() +# compose = ToolCompose( +# tools=list(pool.tools), +# parser=XMLToolCallParser(), +# ) +# return compose, pool +# +# compose, pool = asyncio.run(make_env()) + +##################################################################### +# Conclusion +# ---------- +# +# The combination of *one parse per turn*, *parallel dispatch*, and +# *hardened sandboxes* turns ChatEnv into a SOTA agent backbone without +# introducing a parallel taxonomy of "agent" classes. Tools stay +# composable, ``ChatEnv`` stays minimal, and existing tool transforms +# keep working. + +##################################################################### +# Further reading +# --------------- +# +# - :class:`~torchrl.envs.llm.agentic.ToolCompose` -- API reference. +# - :class:`~torchrl.envs.llm.agentic.sandbox.Sandbox` and +# :class:`~torchrl.envs.llm.agentic.repl.Repl` -- protocol details. +# - The migration table in the LLM Environments reference page. +# - ``benchmarks/test_llm.py::test_toolcompose_parallel_dispatch`` -- +# performance bench.