diff --git a/benchmarks/test_llm.py b/benchmarks/test_llm.py
index aeebbd62491..c729a4dc040 100644
--- a/benchmarks/test_llm.py
+++ b/benchmarks/test_llm.py
@@ -99,3 +99,99 @@ def setup():
warmup_rounds=3,
setup=setup,
)
+
+
+# ----- Agentic ToolCompose dispatch benchmarks -----
+
+
+class _SleepyTool:
+ """Bench-only tool: simulates a network/I/O call via asyncio.sleep."""
+
+ description = "sleep"
+ input_schema = {
+ "type": "object",
+ "properties": {"ms": {"type": "integer"}},
+ }
+ output_schema = None
+ wants_state = False
+
+ def __init__(self, name: str) -> None:
+ self.name = name
+
+ async def setup(self) -> None:
+ pass
+
+ async def teardown(self) -> None:
+ pass
+
+ async def run(self, args, ctx):
+ import asyncio as _asyncio
+
+ from torchrl.envs.llm.agentic import ToolResult
+
+ await _asyncio.sleep(args.get("ms", 100) / 1000)
+ return ToolResult.from_text("ok")
+
+
+@pytest.mark.benchmark(group="agentic-dispatch")
+@pytest.mark.parametrize("n_tools", [3, 8])
+def test_toolcompose_parallel_dispatch(benchmark, n_tools):
+ """Bench parallel ToolCompose dispatch.
+
+ With ``n_tools`` async tools each sleeping 50ms, parallel dispatch
+ should bottom out near 50ms regardless of ``n_tools``; serial
+ dispatch would scale linearly.
+ """
+ from torchrl.envs import TransformedEnv
+ from torchrl.envs.llm import ChatEnv
+ from torchrl.envs.llm.agentic import ToolCompose
+ from torchrl.envs.llm.agentic.parsers import XMLToolCallParser
+
+ set_list_to_stack(True).set()
+ base = ChatEnv(batch_size=(1,), input_mode="history")
+ tools = [_SleepyTool(f"t{i}") for i in range(n_tools)]
+ env = TransformedEnv(
+ base, ToolCompose(tools=tools, parser=XMLToolCallParser())
+ )
+
+    fake = "".join(
+        f'<tool name="t{i}">{{"ms": 50}}</tool>'
+        for i in range(n_tools)
+    )
+
+ def go():
+ obs = env.reset(TensorDict({"query": "go"}, batch_size=(1,)))
+ obs["history"].full = obs["history"].prompt.extend(
+ History(role="assistant", content=fake).view(1, 1), dim=-1
+ )
+ env.step(obs)
+
+ benchmark(go)
+
+
+@pytest.mark.benchmark(group="agentic-dispatch")
+def test_toolcompose_single_call_baseline(benchmark):
+ """One-call baseline so the n=3 / n=8 numbers are interpretable."""
+ from torchrl.envs import TransformedEnv
+ from torchrl.envs.llm import ChatEnv
+ from torchrl.envs.llm.agentic import ToolCompose
+ from torchrl.envs.llm.agentic.parsers import XMLToolCallParser
+
+ set_list_to_stack(True).set()
+ base = ChatEnv(batch_size=(1,), input_mode="history")
+ env = TransformedEnv(
+ base,
+ ToolCompose(
+ tools=[_SleepyTool("t0")], parser=XMLToolCallParser()
+ ),
+ )
+    fake = '<tool name="t0">{"ms": 50}</tool>'
+
+ def go():
+ obs = env.reset(TensorDict({"query": "go"}, batch_size=(1,)))
+ obs["history"].full = obs["history"].prompt.extend(
+ History(role="assistant", content=fake).view(1, 1), dim=-1
+ )
+ env.step(obs)
+
+ benchmark(go)
diff --git a/test/llm/test_agentic.py b/test/llm/test_agentic.py
index 6816ac56bd9..666707a0393 100644
--- a/test/llm/test_agentic.py
+++ b/test/llm/test_agentic.py
@@ -15,15 +15,26 @@
import asyncio
import json
import sys
+import time as _time
import warnings
+from typing import ClassVar
import pytest
-from tensordict import set_list_to_stack
+from tensordict import set_list_to_stack, TensorDict
+from torchrl.data.llm import History
+from torchrl.envs import TransformedEnv
+from torchrl.envs.llm import ChatEnv
from torchrl.envs.llm.agentic import (
ParsedCall,
+ PythonTool,
+ RateLimiter,
+ ShellTool,
+ StopTool,
Tool,
ToolCallParser,
+ ToolCompose,
+ ToolContext,
ToolResult,
validate_args,
)
@@ -43,6 +54,8 @@
)
from torchrl.envs.llm.agentic.sandbox.subprocess_bwrap import _has_bwrap
from torchrl.envs.llm.agentic.sandbox.subprocess_seatbelt import _has_sandbox_exec
+from torchrl.envs.llm.agentic.tools import as_tool
+from torchrl.envs.llm.transforms import IncrementalTokenizer
@pytest.fixture(scope="module", autouse=True)
@@ -431,3 +444,365 @@ async def go():
assert r2.stdout.strip() == "42"
_run(go())
+
+
+# ----- ToolCompose, builtins, legacy adapter -----
+
+
+class _Sleeper:
+ description: ClassVar[str] = "sleep N ms"
+ input_schema = {
+ "type": "object",
+ "properties": {"ms": {"type": "integer"}},
+ }
+ output_schema = None
+ wants_state = False
+
+ def __init__(self, name):
+ self.name = name
+
+ async def setup(self):
+ pass
+
+ async def teardown(self):
+ pass
+
+ async def run(self, args, ctx):
+ await asyncio.sleep(args.get("ms", 100) / 1000)
+ return ToolResult.from_text(f"{self.name}-done")
+
+
+class _Stateful:
+ name: ClassVar[str] = "stateful"
+ description: ClassVar[str] = "needs state"
+ input_schema = {"type": "object"}
+ output_schema = None
+ wants_state = True
+ received_state = None
+
+ async def setup(self):
+ pass
+
+ async def teardown(self):
+ pass
+
+ async def run(self, args, ctx):
+ type(self).received_state = ctx.state
+ return ToolResult.from_text("ok")
+
+
+class _Boom:
+ name: ClassVar[str] = "boom"
+ description = "always fails"
+ input_schema = {"type": "object"}
+ output_schema = None
+ wants_state = False
+
+ async def setup(self):
+ pass
+
+ async def teardown(self):
+ pass
+
+ async def run(self, args, ctx):
+ raise RuntimeError("boom")
+
+
+def _agentic_env(tools, parser=None):
+ parser = parser or XMLToolCallParser()
+ base = ChatEnv(batch_size=(1,), input_mode="history")
+ return TransformedEnv(base, ToolCompose(tools=tools, parser=parser))
+
+
+def _push_assistant(obs, response: str):
+ obs["history"].full = obs["history"].prompt.extend(
+ History(role="assistant", content=response).view(1, 1), dim=-1
+ )
+
+
+class TestToolCompose:
+ def test_rejects_non_tool(self):
+ with pytest.raises(TypeError):
+ ToolCompose(tools=[object()], parser=XMLToolCallParser())
+
+ def test_rejects_duplicate_names(self):
+ with pytest.raises(ValueError):
+ ToolCompose(
+ tools=[_Sleeper("dup"), _Sleeper("dup")],
+ parser=XMLToolCallParser(),
+ )
+
+ def test_append_transform_blocked(self):
+ compose = ToolCompose(
+ tools=[StopTool()], parser=XMLToolCallParser()
+ )
+ with pytest.raises(TypeError):
+ compose.append_transform(IncrementalTokenizer)
+
+ def test_lookup_by_name(self):
+ compose = ToolCompose(
+ tools=[StopTool()], parser=XMLToolCallParser()
+ )
+ assert "stop" in compose
+ assert compose["stop"].name == "stop"
+
+ def test_parallel_dispatch_wall_time(self):
+ env = _agentic_env([_Sleeper("a"), _Sleeper("b"), _Sleeper("c")])
+ obs = env.reset(TensorDict({"query": "go"}, batch_size=(1,)))
+ _push_assistant(
+ obs,
+            '<tool name="a">{"ms": 500}</tool>'
+            '<tool name="b">{"ms": 500}</tool>'
+            '<tool name="c">{"ms": 500}</tool>',
+ )
+ t0 = _time.monotonic()
+ nxt = env.step(obs)
+ elapsed = _time.monotonic() - t0
+        # Three 500ms tools must run concurrently: total < 0.9s.
+        assert elapsed < 0.9, (
+            f"parallel dispatch took {elapsed:.2f}s; expected < 0.9s"
+ )
+ assert bool(nxt.get(("next", "agentic", "any_tool_calls")).item())
+ assert not bool(nxt.get(("next", "agentic", "stop_requested")).item())
+
+ def test_stop_tool_terminates(self):
+ env = _agentic_env([StopTool()])
+ obs = env.reset(TensorDict({"query": "stop"}, batch_size=(1,)))
+        _push_assistant(obs, '<tool name="stop">{"reason":"done"}</tool>')
+ nxt = env.step(obs)
+ assert bool(nxt.get(("next", "agentic", "stop_requested")).item())
+
+ def test_no_tool_calls_passthrough(self):
+ env = _agentic_env([StopTool()])
+ obs = env.reset(TensorDict({"query": "nothing"}, batch_size=(1,)))
+ _push_assistant(obs, "I have nothing to call.")
+ nxt = env.step(obs)
+ assert not bool(nxt.get(("next", "agentic", "any_tool_calls")).item())
+
+ def test_unknown_tool_reports_error(self):
+ env = _agentic_env([StopTool()])
+ obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,)))
+        _push_assistant(obs, '<tool name="does_not_exist">{}</tool>')
+ nxt = env.step(obs)
+ assert bool(nxt.get(("next", "agentic", "any_error")).item())
+
+ def test_failure_isolation(self):
+ env = _agentic_env([_Sleeper("a"), _Boom()])
+ obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,)))
+ _push_assistant(
+ obs,
+            '<tool name="a">{"ms": 5}</tool>'
+            '<tool name="boom">{}</tool>',
+ )
+ nxt = env.step(obs)
+ # both calls fired; one failed, one succeeded.
+ assert bool(nxt.get(("next", "agentic", "any_error")).item())
+ assert bool(nxt.get(("next", "agentic", "any_tool_calls")).item())
+ # The history should have both tool messages appended.
+ prompt = nxt[("next", "history")].prompt
+ # Assistant message + 2 tool messages = at least 3 entries beyond
+ # the original prompt.
+ assert len(prompt[0]) >= 3
+
+ def test_stable_call_id_round_trip(self):
+ captured: list[str] = []
+
+ class _Recorder:
+ name: ClassVar[str] = "rec"
+ description = "records call_id"
+ input_schema = {"type": "object"}
+ output_schema = None
+ wants_state = False
+
+ async def setup(self):
+ pass
+
+ async def teardown(self):
+ pass
+
+ async def run(self, args, ctx):
+ captured.append(ctx.call_id)
+ return ToolResult.from_text(f"id={ctx.call_id}")
+
+ env = _agentic_env([_Recorder()])
+ obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,)))
+        _push_assistant(obs, '<tool name="rec" id="my-id">{}</tool>')
+ nxt = env.step(obs)
+ assert captured == ["my-id"]
+ # The rendered tool message must reference the same call_id.
+ prompt = nxt[("next", "history")].prompt
+ last_msg = prompt[0][-1]
+ assert "my-id" in last_msg.content
+
+ def test_pass_state_to_tools(self):
+ _Stateful.received_state = None
+ base = ChatEnv(batch_size=(1,), input_mode="history")
+ env = TransformedEnv(
+ base,
+ ToolCompose(
+ tools=[_Stateful()],
+ parser=XMLToolCallParser(),
+ pass_state_to_tools=True,
+ ),
+ )
+ obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,)))
+        _push_assistant(obs, '<tool name="stateful">{}</tool>')
+ env.step(obs)
+ assert _Stateful.received_state is not None
+
+ def test_pass_state_off_means_no_state(self):
+ _Stateful.received_state = None
+ env = _agentic_env([_Stateful()]) # pass_state_to_tools defaults False
+ obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,)))
+        _push_assistant(obs, '<tool name="stateful">{}</tool>')
+ env.step(obs)
+ assert _Stateful.received_state is None
+
+ def test_rate_limit_serializes_concurrent_calls(self):
+ # max_concurrent=1 forces 3 calls of 200ms each to take >= 600ms.
+ slow = _Sleeper("slow")
+ compose = ToolCompose(
+ tools=[slow],
+ parser=XMLToolCallParser(),
+ rate_limits={"slow": RateLimiter(max_concurrent=1)},
+ )
+ base = ChatEnv(batch_size=(1,), input_mode="history")
+ env = TransformedEnv(base, compose)
+ obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,)))
+ _push_assistant(
+ obs,
+            '<tool name="slow">{"ms": 200}</tool>'
+            '<tool name="slow">{"ms": 200}</tool>'
+            '<tool name="slow">{"ms": 200}</tool>',
+ )
+ t0 = _time.monotonic()
+ env.step(obs)
+ elapsed = _time.monotonic() - t0
+ assert elapsed >= 0.55, (
+ f"rate-limited dispatch should serialize: got {elapsed:.2f}s"
+ )
+
+ def test_argument_validation(self):
+ class _NeedsCode:
+ name: ClassVar[str] = "needs"
+ description = ""
+ input_schema = {
+ "type": "object",
+ "properties": {"code": {"type": "string"}},
+ "required": ["code"],
+ }
+ output_schema = None
+ wants_state = False
+
+ async def setup(self):
+ pass
+
+ async def teardown(self):
+ pass
+
+ async def run(self, args, ctx): # pragma: no cover - never reached
+ return ToolResult.from_text("hit")
+
+ env = _agentic_env([_NeedsCode()])
+ obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,)))
+        _push_assistant(obs, '<tool name="needs">{}</tool>')  # missing required
+ nxt = env.step(obs)
+ assert bool(nxt.get(("next", "agentic", "any_error")).item())
+
+ def test_nested_loop_safety(self):
+ # When the caller already owns an event loop, ToolCompose._step
+ # must still complete (offload to a worker thread).
+ async def go():
+ env = _agentic_env([StopTool()])
+ obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,)))
+            _push_assistant(obs, '<tool name="stop">{}</tool>')
+ return env.step(obs)
+
+ nxt = _run(go())
+ assert bool(nxt.get(("next", "agentic", "stop_requested")).item())
+
+
+class TestPythonTool:
+ def test_state_persists_across_calls(self):
+ async def go():
+ async with UnsafeSubprocessSandbox(
+ ResourceLimits(wall_seconds=10)
+ ) as s:
+ tool = PythonTool(repl=SubprocessRepl(s))
+ await tool.setup()
+ ctx = ToolContext(call_id="c1")
+ r1 = await tool.run({"code": "x = 41"}, ctx)
+ assert not r1.is_error
+ r2 = await tool.run({"code": "print(x + 1)"}, ctx)
+ assert not r2.is_error
+ assert "42" in r2.text
+ await tool.teardown()
+
+ _run(go())
+
+ def test_error_marked_is_error(self):
+ async def go():
+ async with UnsafeSubprocessSandbox(
+ ResourceLimits(wall_seconds=10)
+ ) as s:
+ tool = PythonTool(repl=SubprocessRepl(s))
+ await tool.setup()
+ r = await tool.run({"code": "1/0"}, ToolContext(call_id="c"))
+ assert r.is_error
+ assert "ZeroDivisionError" in r.text
+ await tool.teardown()
+
+ _run(go())
+
+
+class TestShellTool:
+ def test_runs_argv(self):
+ async def go():
+ async with UnsafeSubprocessSandbox(
+ ResourceLimits(wall_seconds=5)
+ ) as s:
+ tool = ShellTool(s)
+ await tool.setup()
+ r = await tool.run(
+ {"argv": ["/bin/echo", "hi"]}, ToolContext(call_id="c")
+ )
+ assert not r.is_error
+ assert "hi" in r.text
+ # Don't tear down s twice -- ShellTool teardown closes it.
+
+ _run(go())
+
+
+class TestLegacyAdapter:
+ def test_lifts_legacy_transform(self):
+ # Use a tiny duck-typed legacy class instead of pulling in the
+ # full PythonInterpreter, which has its own subprocess pool.
+ class _LegacyAdder:
+ tool_role = "tool"
+
+            def _process_batch_item(self, content, index):
+                # Echo the captured XML so the assertion can find it.
+                # NOTE(review): this span was mangled in transit (tags
+                # stripped); reconstructed from the surrounding test intent.
+                if "<tool" in content:
+                    return f"legacy got {content}"
+                return None
+
+        env = _agentic_env([as_tool(_LegacyAdder())])
+        obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,)))
+        _push_assistant(obs, '<tool name="legacy">{"a": 1, "b": 2}</tool>')
+ nxt = env.step(obs)
+ prompt = nxt[("next", "history")].prompt
+ # The last appended message should be the tool result containing
+ # the legacy output.
+ assert "legacy got" in prompt[0][-1].content
diff --git a/torchrl/envs/llm/agentic/__init__.py b/torchrl/envs/llm/agentic/__init__.py
index 7bdea7108dc..a18f6a6f698 100644
--- a/torchrl/envs/llm/agentic/__init__.py
+++ b/torchrl/envs/llm/agentic/__init__.py
@@ -16,6 +16,7 @@
"""
from __future__ import annotations
+from .compose import DispatchResult, ToolCompose
from .protocols import (
FileRefPart,
ImagePart,
@@ -30,21 +31,39 @@
ToolResult,
ToolResultPart,
)
+from .rate_limit import RateLimiter
from .schema import json_schema_from_pydantic, validate_args
+from .tools import (
+ FileReadTool,
+ PythonTool,
+ ShellTool,
+ StopSignal,
+ StopTool,
+ as_tool,
+)
__all__ = [
+ "DispatchResult",
+ "FileReadTool",
"FileRefPart",
"ImagePart",
"JsonPart",
"ParseResult",
"ParsedCall",
+ "PythonTool",
+ "RateLimiter",
+ "ShellTool",
+ "StopSignal",
+ "StopTool",
"TextPart",
"Tool",
"ToolCallParser",
+ "ToolCompose",
"ToolContext",
"ToolError",
"ToolResult",
"ToolResultPart",
+ "as_tool",
"json_schema_from_pydantic",
"validate_args",
]
diff --git a/torchrl/envs/llm/agentic/compose.py b/torchrl/envs/llm/agentic/compose.py
new file mode 100644
index 00000000000..5b34c8a2b92
--- /dev/null
+++ b/torchrl/envs/llm/agentic/compose.py
@@ -0,0 +1,519 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""ToolCompose: parallel async tool dispatch over a ChatEnv history.
+
+A :class:`ToolCompose` is a :class:`~torchrl.envs.transforms.Compose`
+subclass that:
+
+1. owns a :class:`~torchrl.envs.llm.agentic.ToolCallParser`,
+2. holds a fixed set of :class:`~torchrl.envs.llm.agentic.Tool` instances
+ (raises :class:`TypeError` on non-Tool insert),
+3. on each step parses the latest assistant message *once*, dispatches
+ matched tools concurrently via :func:`asyncio.gather`, renders each
+ result through the parser, and extends the
+ :class:`~torchrl.data.llm.History` with the resulting tool messages,
+4. surfaces ``("agentic", "any_tool_calls")`` and
+ ``("agentic", "stop_requested")`` keys in the step output for the env
+ to use as termination signals.
+
+The ChatEnv is unchanged; ``ToolCompose`` lives entirely in transform
+space.
+"""
+from __future__ import annotations
+
+import asyncio
+import json
+import threading
+import uuid
+from collections.abc import Iterable, Mapping
+from concurrent.futures import Future
+from dataclasses import dataclass, field
+from typing import Any
+
+import torch
+from tensordict import lazy_stack, TensorDictBase
+from torchrl._utils import logger as torchrl_logger
+from torchrl.data.llm import History
+from torchrl.envs.transforms import Compose, Transform
+
+from .protocols import (
+ ParsedCall,
+ ParseResult,
+ Tool,
+ ToolCallParser,
+ ToolContext,
+ ToolError,
+ ToolResult,
+)
+from .rate_limit import RateLimiter
+from .schema import validate_args
+from .tools.builtin import StopSignal
+
+
+@dataclass
+class DispatchResult:
+ """Aggregate outcome of one :meth:`ToolCompose._dispatch_one` call.
+
+ Attributes:
+ cleaned_text: Assistant text with tool-call syntax stripped.
+ calls: Parsed calls in emission order.
+ results: Tool results, aligned with ``calls``.
+ any_error: ``True`` if any tool failed (including unknown tool).
+ stop_requested: ``True`` if any tool raised :class:`StopSignal`.
+ """
+
+ cleaned_text: str = ""
+ calls: tuple[ParsedCall, ...] = ()
+ results: tuple[ToolResult, ...] = ()
+ any_error: bool = False
+ stop_requested: bool = False
+
+
+class _ToolTransformShim(Transform):
+ """Thin :class:`Transform` wrapper around a :class:`Tool`.
+
+ Stored inside :class:`ToolCompose.transforms` so the parent ``Compose``
+ machinery (module hierarchy, device moves, training-mode propagation)
+ keeps working. The shim has no ``_step`` of its own -- ``ToolCompose``
+ drives dispatch directly and consults this shim only as a holder.
+ """
+
+ def __init__(self, tool: Tool) -> None:
+ super().__init__()
+ # Do NOT register the tool as a submodule -- tools are not nn.Module.
+ # Storing in __dict__ keeps it out of state_dict.
+ object.__setattr__(self, "tool", tool)
+
+ @property
+ def name(self) -> str:
+ return self.tool.name
+
+ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: # noqa: D401
+ # ToolCompose drives dispatch directly; the shim is a pass-through.
+ return next_tensordict
+
+
+class ToolCompose(Compose):
+ """A :class:`Compose` of :class:`~torchrl.envs.llm.agentic.Tool` objects
+ with parallel async dispatch.
+
+ Args:
+ tools: The set of tools the LLM may call. Each must conform to the
+ :class:`~torchrl.envs.llm.agentic.Tool` protocol (``name``,
+ ``input_schema``, async ``run``). Inserting a non-Tool (e.g. a
+ plain :class:`Transform`) raises :class:`TypeError`.
+ parser: The :class:`~torchrl.envs.llm.agentic.ToolCallParser` used
+ to extract calls from each assistant message and render
+ results back as tool messages.
+ rate_limits: Optional per-tool :class:`RateLimiter` map (keyed by
+ tool name). Tools without an entry are unthrottled.
+ per_call_timeout: Default per-call timeout in seconds. ``None``
+ means rely on tool/repl timeouts only.
+ pass_state_to_tools: If ``True`` and a tool has
+ ``wants_state=True``, ``ctx.state`` is populated with a
+ filtered read-only view of the env tensordict (mirrors the
+ legacy ``ExecuteToolsInOrder`` knob).
+ tool_role: Role string used when injecting tool messages into
+ :class:`~torchrl.data.llm.History` (default ``"tool"``).
+ validate_inputs: If ``True`` (default), validate ``args`` against
+ ``tool.input_schema`` before dispatch. Schema mismatches are
+ reported as :class:`ToolResult` with ``is_error=True``.
+
+ Example:
+ >>> from torchrl.envs import TransformedEnv
+ >>> from torchrl.envs.llm import ChatEnv # doctest: +SKIP
+ >>> from torchrl.envs.llm.agentic import ToolCompose # doctest: +SKIP
+ >>> from torchrl.envs.llm.agentic.parsers import XMLToolCallParser
+ >>> from torchrl.envs.llm.agentic.tools import StopTool
+ >>> compose = ToolCompose(
+ ... tools=[StopTool()],
+ ... parser=XMLToolCallParser(),
+ ... )
+ """
+
+ def __init__(
+ self,
+ *,
+ tools: Iterable[Tool],
+ parser: ToolCallParser,
+ rate_limits: Mapping[str, RateLimiter] | None = None,
+ per_call_timeout: float | None = None,
+ pass_state_to_tools: bool = False,
+ tool_role: str = "tool",
+ validate_inputs: bool = True,
+ ) -> None:
+ tool_list = list(tools)
+ for t in tool_list:
+ if not _is_tool(t):
+ raise TypeError(
+ f"ToolCompose accepts Tool objects only; got "
+ f"{type(t).__name__!r}. Wrap a legacy transform with "
+ "torchrl.envs.llm.agentic.tools.as_tool(...)."
+ )
+ if not isinstance(parser, ToolCallParser):
+ raise TypeError(
+ "parser must implement ToolCallParser; got "
+ f"{type(parser).__name__!r}"
+ )
+ names = [t.name for t in tool_list]
+ if len(names) != len(set(names)):
+ seen: set[str] = set()
+ dups = [n for n in names if n in seen or seen.add(n)] # type: ignore[func-returns-value]
+ raise ValueError(f"duplicate tool names: {dups!r}")
+ shims = [_ToolTransformShim(t) for t in tool_list]
+ super().__init__(*shims)
+ self._tool_list: list[Tool] = tool_list
+ self._tools_by_name: dict[str, Tool] = {t.name: t for t in tool_list}
+ self.parser = parser
+ self._rate_limits: dict[str, RateLimiter] = dict(rate_limits or {})
+ self._per_call_timeout = per_call_timeout
+ self._pass_state = pass_state_to_tools
+ self._tool_role = tool_role
+ self._validate_inputs = validate_inputs
+ self._setup_done = False
+
+ # ----- introspection helpers -----
+
+ @property
+ def tools(self) -> tuple[Tool, ...]:
+ return tuple(self._tool_list)
+
+ def __getitem__(self, key: str | int): # type: ignore[override]
+ if isinstance(key, str):
+ return self._tools_by_name[key]
+ return super().__getitem__(key)
+
+ def __contains__(self, key: object) -> bool: # type: ignore[override]
+ if isinstance(key, str):
+ return key in self._tools_by_name
+ return super().__contains__(key)
+
+ # ----- enforce Tool-only insertion -----
+
+ def append_transform(self, transform): # type: ignore[override]
+ # ``Compose.append_transform`` accepts arbitrary Transforms. We
+ # constrain to Tools to keep the dispatch invariants intact.
+ raise TypeError(
+ "ToolCompose does not accept arbitrary transforms. Use "
+ "append_tool(tool) instead."
+ )
+
+ def append_tool(self, tool: Tool) -> None:
+ """Append a :class:`Tool` to the dispatch set."""
+ if not _is_tool(tool):
+ raise TypeError(
+ f"append_tool requires a Tool; got {type(tool).__name__!r}"
+ )
+ if tool.name in self._tools_by_name:
+ raise ValueError(f"duplicate tool name: {tool.name!r}")
+ shim = _ToolTransformShim(tool)
+ self.transforms.append(shim)
+ shim.set_container(self)
+ self._tool_list.append(tool)
+ self._tools_by_name[tool.name] = tool
+
+ # ----- lifecycle -----
+
+ async def _setup_tools(self) -> None:
+ if self._setup_done:
+ return
+ for tool in self._tool_list:
+ try:
+ await tool.setup()
+ except Exception: # pragma: no cover -- per-tool setup is best-effort
+ torchrl_logger.exception(
+ "tool %r setup raised; continuing", tool.name
+ )
+ self._setup_done = True
+
+ async def _teardown_tools(self) -> None:
+ for tool in self._tool_list:
+ try:
+ await tool.teardown()
+ except Exception: # pragma: no cover
+ torchrl_logger.exception(
+ "tool %r teardown raised; continuing", tool.name
+ )
+ self._setup_done = False
+
+ def close(self) -> None: # type: ignore[override]
+ super().close()
+ try:
+ _run_async(self._teardown_tools())
+ except Exception: # pragma: no cover
+ torchrl_logger.exception("teardown_tools raised; continuing")
+
+ # ----- the step path -----
+
+ def _step(
+ self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+ ) -> TensorDictBase:
+ if next_tensordict.batch_dims > 1:
+ with next_tensordict.view(-1) as flat:
+ flat = self._step(tensordict, flat)
+ return next_tensordict
+ history = self._extract_history(next_tensordict)
+ last = history[..., -1]
+ contents = last.content
+ if isinstance(contents, str):
+ contents = [contents]
+ # asyncio.gather across batch items.
+ dispatch_results: list[DispatchResult] = _run_async(
+ self._dispatch_batch(list(contents), next_tensordict)
+ )
+ return self._inject_results(history, dispatch_results, next_tensordict)
+
+ def _extract_history(self, td: TensorDictBase) -> History:
+ parent = self.parent
+ if parent is None:
+ raise RuntimeError(
+ "ToolCompose must be used inside a TransformedEnv"
+ )
+ base_env = parent.base_env
+ if getattr(base_env, "input_mode", None) != "history":
+ raise RuntimeError(
+ "ToolCompose requires the underlying ChatEnv to use "
+ "input_mode='history' (got "
+ f"{getattr(base_env, 'input_mode', None)!r})"
+ )
+ return td["history"].prompt
+
+ async def _dispatch_batch(
+ self, contents: list[str], td: TensorDictBase
+ ) -> list[DispatchResult]:
+ await self._setup_tools()
+ # Each batch item runs concurrently with every other; within a
+ # batch item, calls also run concurrently.
+ return list(
+ await asyncio.gather(
+ *(
+ self._dispatch_one(content, td, batch_index=i)
+ for i, content in enumerate(contents)
+ )
+ )
+ )
+
+ async def _dispatch_one(
+ self, content: str, td: TensorDictBase, *, batch_index: int
+ ) -> DispatchResult:
+ parsed: ParseResult = self.parser.parse(content)
+ if not parsed.calls:
+ return DispatchResult(cleaned_text=parsed.text)
+ # Parallel run all calls in this batch item.
+ coros = [
+ self._run_one(call, td, batch_index=batch_index)
+ for call in parsed.calls
+ ]
+ outcomes = await asyncio.gather(*coros, return_exceptions=False)
+ results: list[ToolResult] = []
+ any_error = False
+ stop = False
+ for r in outcomes:
+ if isinstance(r, _StopMarker):
+ results.append(r.result)
+ stop = True
+ else:
+ results.append(r)
+ if r.is_error:
+ any_error = True
+ return DispatchResult(
+ cleaned_text=parsed.text,
+ calls=parsed.calls,
+ results=tuple(results),
+ any_error=any_error,
+ stop_requested=stop,
+ )
+
+ async def _run_one(
+ self, call: ParsedCall, td: TensorDictBase, *, batch_index: int
+ ):
+ tool = self._tools_by_name.get(call.tool)
+ if tool is None:
+ return ToolResult.from_text(
+ f"unknown tool: {call.tool!r}",
+ is_error=True,
+ meta={"call_id": call.call_id},
+ )
+ if self._validate_inputs:
+ try:
+ validate_args(call.args, tool.input_schema)
+ except Exception as e:
+ return ToolResult.from_text(
+ f"argument validation failed for {tool.name!r}: {e}",
+ is_error=True,
+ meta={"call_id": call.call_id},
+ )
+ ctx = ToolContext(
+ call_id=call.call_id,
+ tag=call.tag,
+ state=self._filter_state(td, batch_index) if self._pass_state and getattr(
+ tool, "wants_state", False
+ ) else None,
+ sandbox=getattr(tool, "sandbox", None),
+ repl=getattr(tool, "repl", None),
+ compose=self,
+ )
+ limiter = self._rate_limits.get(tool.name)
+ timeout = self._per_call_timeout
+ try:
+ if limiter is not None:
+ async with limiter.slot():
+ coro = tool.run(call.args, ctx)
+ if timeout is None:
+ return await coro
+ return await asyncio.wait_for(coro, timeout=timeout)
+ coro = tool.run(call.args, ctx)
+ if timeout is None:
+ return await coro
+ return await asyncio.wait_for(coro, timeout=timeout)
+ except StopSignal as s:
+ return _StopMarker(
+ ToolResult.from_text(
+ f"[stop] {s}", meta={"call_id": call.call_id, "stop": True}
+ )
+ )
+ except ToolError as e:
+ return ToolResult.from_text(
+ str(e), is_error=True, meta={"call_id": call.call_id}
+ )
+ except asyncio.TimeoutError:
+ return ToolResult.from_text(
+ f"tool {tool.name!r} timed out after {timeout}s",
+ is_error=True,
+ meta={"call_id": call.call_id, "timed_out": True},
+ )
+ except Exception as e: # noqa: BLE001 -- failure isolation
+ torchrl_logger.exception(
+ "tool %r raised; reporting as error", tool.name
+ )
+ return ToolResult.from_text(
+ f"{type(e).__name__}: {e}",
+ is_error=True,
+ meta={"call_id": call.call_id},
+ )
+
+ def _filter_state(
+ self, td: TensorDictBase, batch_index: int
+ ) -> TensorDictBase | None:
+ try:
+ view = td[batch_index] if td.batch_dims else td
+ except Exception: # pragma: no cover
+ return None
+ try:
+ return view.exclude("history")
+ except Exception: # pragma: no cover
+ return view
+
+ # ----- result injection -----
+
+ def _inject_results(
+ self,
+ history: History,
+ dispatches: list[DispatchResult],
+ td: TensorDictBase,
+ ) -> TensorDictBase:
+ # Build per-batch-item History extension lists and the agentic
+ # signals.
+ any_calls = [bool(d.calls) for d in dispatches]
+ any_error = [d.any_error for d in dispatches]
+ any_stop = [d.stop_requested for d in dispatches]
+
+ per_item_messages: list[list[History]] = []
+ for d in dispatches:
+ messages: list[History] = []
+ for call, result in zip(d.calls, d.results):
+ rendered = self.parser.render_result(call.call_id, result)
+ content = rendered.get("content", "")
+ if isinstance(content, list):
+ content = json.dumps(content, ensure_ascii=False)
+ messages.append(
+ History(role=self._tool_role, content=str(content))
+ )
+ per_item_messages.append(messages)
+
+ max_len = max((len(m) for m in per_item_messages), default=0)
+ if max_len > 0:
+ padded: list[list[History]] = []
+ for messages in per_item_messages:
+ if len(messages) == max_len:
+ padded.append(messages)
+ else:
+ padded.append(
+ messages
+ + [History(role="", content="")]
+ * (max_len - len(messages))
+ )
+ stacked = lazy_stack([lazy_stack(m) for m in padded])
+ history.extend(stacked, dim=-1)
+ td["history"].prompt = history
+
+ device = td.device
+ td.set(
+ ("agentic", "any_tool_calls"),
+ torch.tensor(any_calls, dtype=torch.bool, device=device),
+ )
+ td.set(
+ ("agentic", "any_error"),
+ torch.tensor(any_error, dtype=torch.bool, device=device),
+ )
+ td.set(
+ ("agentic", "stop_requested"),
+ torch.tensor(any_stop, dtype=torch.bool, device=device),
+ )
+ return td
+
+
+@dataclass
+class _StopMarker:
+ """Internal carrier flagging a :class:`StopSignal` from a tool."""
+
+ result: ToolResult
+
+
+def _is_tool(obj: Any) -> bool:
+ """Duck-typed Tool conformance check.
+
+ ``Tool`` is a runtime-checkable Protocol but Protocol checks treat
+ every required attribute as needing to be present at the *instance*
+ level; class-level ``ClassVar``s satisfy this in CPython but the check
+ is unreliable across versions when adapters set per-instance ``name``.
+ Reuse a simpler attribute-presence test.
+ """
+ required = ("name", "input_schema", "run", "setup", "teardown")
+ return all(hasattr(obj, a) for a in required) and callable(obj.run)
+
+
+# ----- async runner: nested-loop safe -----
+
+def _run_async(coro):
+ """Run ``coro`` to completion regardless of whether the caller is
+ inside an event loop.
+
+ - No running loop: :func:`asyncio.run`.
+ - Running loop: dispatch on a worker thread that owns its own loop
+ and join. (Necessary because Compose._step is sync and may be
+ called from inside a Jupyter-style outer loop.)
+ """
+ try:
+ asyncio.get_running_loop()
+ running = True
+ except RuntimeError:
+ running = False
+ if not running:
+ return asyncio.run(coro)
+ fut: Future = Future()
+
+ def _target():
+ try:
+ fut.set_result(asyncio.run(coro))
+ except BaseException as e: # noqa: BLE001
+ fut.set_exception(e)
+
+ t = threading.Thread(target=_target, daemon=True)
+ t.start()
+ return fut.result()
+
+
+__all__ = ["DispatchResult", "ToolCompose"]
diff --git a/torchrl/envs/llm/agentic/rate_limit.py b/torchrl/envs/llm/agentic/rate_limit.py
new file mode 100644
index 00000000000..f463be9b1e3
--- /dev/null
+++ b/torchrl/envs/llm/agentic/rate_limit.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Asyncio rate limiter -- semaphore + token bucket.
+
+A :class:`RateLimiter` caps concurrent in-flight calls and (optionally) the
+sustained call rate. Used by :class:`~torchrl.envs.llm.agentic.ToolCompose`
+to throttle individual tools (e.g. a search API at 5 QPS) without blocking
+the rest of the dispatch.
+"""
+from __future__ import annotations
+
+import asyncio
+import time
+from contextlib import asynccontextmanager
+
+
class RateLimiter:
    """Combined semaphore + token-bucket throttle.

    Args:
        max_concurrent: Cap on simultaneously in-flight calls. ``None`` for
            unlimited.
        rate_per_second: Sustained refill rate. ``None`` disables the
            token bucket; only the semaphore is enforced.
        burst: Token-bucket capacity. Defaults to ``rate_per_second``.

    Examples:
        >>> import asyncio
        >>> async def go():
        ...     limiter = RateLimiter(max_concurrent=2)
        ...     async with limiter.slot():
        ...         pass
        >>> asyncio.run(go())
    """

    def __init__(
        self,
        *,
        max_concurrent: int | None = None,
        rate_per_second: float | None = None,
        burst: float | None = None,
    ) -> None:
        # NOTE: max_concurrent=0 is treated like None (no semaphore cap).
        self._sem: asyncio.Semaphore | None = (
            asyncio.Semaphore(max_concurrent) if max_concurrent else None
        )
        self._rate = rate_per_second
        # The bucket starts full, so the first `burst` calls pass at once.
        self._capacity = burst if burst is not None else (rate_per_second or 0.0)
        self._tokens = self._capacity
        self._last = time.monotonic()
        self._lock = asyncio.Lock()

    async def _consume(self) -> None:
        """Take one token, sleeping until the bucket refills when empty.

        Fix vs. the previous sleep-once scheme: after sleeping we loop
        back, re-acquire the lock, and *re-check* the bucket before
        consuming. Previously every coroutine that observed an empty
        bucket slept once and then blindly decremented (clamped at 0),
        so concurrent callers could collectively exceed
        ``rate_per_second``.
        """
        if not self._rate:
            return
        while True:
            async with self._lock:
                now = time.monotonic()
                # Refill proportionally to elapsed time, capped at capacity.
                self._tokens = min(
                    self._capacity, self._tokens + (now - self._last) * self._rate
                )
                self._last = now
                if self._tokens >= 1.0:
                    self._tokens -= 1.0
                    return
                # Not enough tokens: compute the wait outside the lock so
                # other coroutines are not blocked while we sleep.
                wait = (1.0 - self._tokens) / self._rate
            await asyncio.sleep(wait)

    @asynccontextmanager
    async def slot(self):
        """Acquire one slot. Blocks until both the semaphore and token
        bucket allow."""
        await self._consume()
        if self._sem is None:
            yield
            return
        async with self._sem:
            yield
+
+
+__all__ = ["RateLimiter"]
diff --git a/torchrl/envs/llm/agentic/tools/__init__.py b/torchrl/envs/llm/agentic/tools/__init__.py
new file mode 100644
index 00000000000..2795b09c4fa
--- /dev/null
+++ b/torchrl/envs/llm/agentic/tools/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Built-in tools and the legacy-transform adapter.
+
+- :class:`PythonTool`, :class:`ShellTool`, :class:`FileReadTool`,
+ :class:`StopTool` -- ready-to-use tools wired to the agentic
+ Sandbox/Repl primitives.
+- :func:`as_tool` -- adapter lifting any legacy
+ :class:`~torchrl.envs.llm.transforms.tools.ToolTransformBase` subclass
+ into a :class:`~torchrl.envs.llm.agentic.Tool` so existing user code
+ keeps working inside :class:`~torchrl.envs.llm.agentic.ToolCompose`.
+"""
+from __future__ import annotations
+
+from .builtin import FileReadTool, PythonTool, ShellTool, StopTool, StopSignal
+from .legacy_adapter import as_tool
+
# Names re-exported by torchrl.envs.llm.agentic.tools.
__all__ = [
    "FileReadTool",
    "PythonTool",
    "ShellTool",
    "StopSignal",
    "StopTool",
    "as_tool",
]
diff --git a/torchrl/envs/llm/agentic/tools/builtin.py b/torchrl/envs/llm/agentic/tools/builtin.py
new file mode 100644
index 00000000000..f0c7dc688ea
--- /dev/null
+++ b/torchrl/envs/llm/agentic/tools/builtin.py
@@ -0,0 +1,241 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Built-in tools.
+
+- :class:`PythonTool` -- run code in a :class:`Repl` (state persists
+ across calls in the same episode).
+- :class:`ShellTool` -- run argv inside a :class:`Sandbox`.
+- :class:`FileReadTool` -- read a file from inside a :class:`Sandbox`.
+- :class:`StopTool` -- explicit episode terminator. Raises
+ :class:`StopSignal` so the dispatcher can mark the episode done.
+"""
+from __future__ import annotations
+
+import shlex
+from collections.abc import Mapping
+from typing import Any, ClassVar
+
+from ..protocols import (
+ ParsedCall,
+ TextPart,
+ Tool,
+ ToolContext,
+ ToolError,
+ ToolResult,
+)
+from ..repl.base import Repl
+from ..sandbox.base import Sandbox
+
+
class StopSignal(Exception):
    """Raised by :class:`StopTool` to terminate the agent loop.

    :class:`~torchrl.envs.llm.agentic.ToolCompose` catches this and sets
    the corresponding episode-end flag in the step output.

    The exception message carries the stop ``reason`` supplied by the
    model (see :meth:`StopTool.run`).
    """
+
+
class PythonTool:
    """Execute Python code in a stateful :class:`Repl`.

    Args:
        repl: The REPL backend. Must be opened (or used as a context
            manager) before dispatch. ``ToolCompose`` opens/closes it
            on env reset/close when it owns the repl.
        timeout: Timeout (seconds) passed to every ``repl.execute`` call.
        output_max_chars: Cap on the returned text (longer is truncated
            with a marker).

    Examples:
        >>> from torchrl.envs.llm.agentic.sandbox import UnsafeSubprocessSandbox
        >>> from torchrl.envs.llm.agentic.repl import SubprocessRepl
        >>> tool = PythonTool(repl=SubprocessRepl(UnsafeSubprocessSandbox()))
    """

    name: ClassVar[str] = "python"
    description: ClassVar[str] = "Execute Python code; state persists across calls."
    input_schema: ClassVar[Mapping[str, Any]] = {
        "type": "object",
        "properties": {"code": {"type": "string"}},
        "required": ["code"],
    }
    output_schema: ClassVar[Mapping[str, Any] | None] = None
    wants_state: ClassVar[bool] = False

    def __init__(
        self,
        repl: Repl,
        *,
        timeout: float | None = 30.0,
        output_max_chars: int = 8192,
    ) -> None:
        self.repl = repl
        self.timeout = timeout
        self.output_max_chars = output_max_chars

    async def setup(self) -> None:
        """Open the underlying REPL."""
        await self.repl.open()

    async def teardown(self) -> None:
        """Close the underlying REPL."""
        await self.repl.close()

    async def run(
        self, args: Mapping[str, Any], ctx: ToolContext
    ) -> ToolResult:
        """Execute ``args['code']`` and return its (possibly truncated) output.

        Raises:
            ToolError: If ``code`` is present but not a string.
        """
        code = args.get("code", "")
        if not isinstance(code, str):
            raise ToolError("'code' must be a string")
        result = await self.repl.execute(code, timeout=self.timeout)
        text = result.text
        truncated = False
        # Cap output length; the marker makes truncation visible to the LLM.
        if len(text) > self.output_max_chars:
            text = text[: self.output_max_chars] + "\n... [truncated]"
            truncated = True
        return ToolResult(
            parts=(TextPart(text=text),),
            # Execution errors and timeouts both count as tool errors.
            is_error=result.error is not None or result.timed_out,
            meta={
                "execution_count": result.execution_count,
                "timed_out": result.timed_out,
                "truncated": truncated,
            },
        )
+
+
class ShellTool:
    """Run a command line inside a :class:`Sandbox`.

    Callers pass either ``argv: list[str]`` or ``command: str``. A bare
    ``command`` is tokenised with :func:`shlex.split`, so anything that
    needs shell features (pipes, redirection) must be spelled out as
    ``argv=["sh", "-c", "..."]``.
    """

    name: ClassVar[str] = "shell"
    description: ClassVar[str] = "Execute a shell command in a sandbox."
    input_schema: ClassVar[Mapping[str, Any]] = {
        "type": "object",
        "properties": {
            "command": {"type": "string"},
            "argv": {"type": "array", "items": {"type": "string"}},
            "cwd": {"type": "string"},
        },
    }
    output_schema: ClassVar[Mapping[str, Any] | None] = None
    wants_state: ClassVar[bool] = False

    def __init__(self, sandbox: Sandbox) -> None:
        self.sandbox = sandbox

    async def setup(self) -> None:
        """Open the backing sandbox."""
        await self.sandbox.open()

    async def teardown(self) -> None:
        """Close the backing sandbox."""
        await self.sandbox.close()

    async def run(
        self, args: Mapping[str, Any], ctx: ToolContext
    ) -> ToolResult:
        """Execute the requested command and format stdout/stderr/exit code.

        Raises:
            ToolError: If neither ``argv`` nor ``command`` was supplied.
        """
        argv = args.get("argv")
        if argv is None:
            command = args.get("command")
            if command is None:
                raise ToolError("ShellTool requires 'argv' or 'command'")
            argv = shlex.split(str(command))
        outcome = await self.sandbox.run(list(argv), cwd=args.get("cwd"))
        sections: list[str] = []
        if outcome.stdout:
            sections.append(outcome.stdout)
        if outcome.stderr:
            sections.append(f"[stderr]\n{outcome.stderr}")
        sections.append(f"[exit {outcome.exit_code}]")
        return ToolResult(
            parts=(TextPart(text="\n".join(sections).rstrip()),),
            is_error=outcome.exit_code != 0 or outcome.timed_out,
            meta={
                "exit_code": outcome.exit_code,
                "timed_out": outcome.timed_out,
                "wall_seconds": outcome.wall_seconds,
            },
        )
+
+
class FileReadTool:
    """Fetch a file's contents from the :class:`Sandbox` filesystem."""

    name: ClassVar[str] = "file_read"
    description: ClassVar[str] = "Read a file from the sandbox filesystem."
    input_schema: ClassVar[Mapping[str, Any]] = {
        "type": "object",
        "properties": {
            "path": {"type": "string"},
            "max_bytes": {"type": "integer"},
        },
        "required": ["path"],
    }
    output_schema: ClassVar[Mapping[str, Any] | None] = None
    wants_state: ClassVar[bool] = False

    def __init__(self, sandbox: Sandbox) -> None:
        self.sandbox = sandbox

    async def setup(self) -> None:
        """Open the backing sandbox."""
        await self.sandbox.open()

    async def teardown(self) -> None:
        """Close the backing sandbox."""
        await self.sandbox.close()

    async def run(
        self, args: Mapping[str, Any], ctx: ToolContext
    ) -> ToolResult:
        """Read ``args['path']`` (optionally capped at ``max_bytes``).

        Raises:
            ToolError: If the file does not exist in the sandbox.
        """
        target = args["path"]
        limit = args.get("max_bytes")
        try:
            raw = await self.sandbox.read_file(target, max_bytes=limit)
        except FileNotFoundError as exc:
            raise ToolError(f"file not found: {target}") from exc
        # Undecodable byte sequences are replaced rather than raising.
        return ToolResult.from_text(
            raw.decode("utf-8", errors="replace"),
            meta={"bytes": len(raw)},
        )
+
+
class StopTool:
    """Zero-argument tool that terminates the agent episode.

    :meth:`run` never returns normally -- it raises :class:`StopSignal`,
    which the dispatcher converts into the episode-done flag of the step
    output.
    """

    name: ClassVar[str] = "stop"
    description: ClassVar[str] = "Signal that the agent has finished its task."
    input_schema: ClassVar[Mapping[str, Any]] = {
        "type": "object",
        "properties": {"reason": {"type": "string"}},
    }
    output_schema: ClassVar[Mapping[str, Any] | None] = None
    wants_state: ClassVar[bool] = False

    async def setup(self) -> None:
        """No resources to acquire."""
        return None

    async def teardown(self) -> None:
        """No resources to release."""
        return None

    async def run(
        self, args: Mapping[str, Any], ctx: ToolContext
    ) -> ToolResult:
        """Raise :class:`StopSignal` carrying the model-supplied reason."""
        raise StopSignal(str(args.get("reason", "done")))
+
+
# Explicit public surface of this module.
__all__ = [
    "FileReadTool",
    "PythonTool",
    "ShellTool",
    "StopSignal",
    "StopTool",
]
diff --git a/torchrl/envs/llm/agentic/tools/legacy_adapter.py b/torchrl/envs/llm/agentic/tools/legacy_adapter.py
new file mode 100644
index 00000000000..c8a5979c51a
--- /dev/null
+++ b/torchrl/envs/llm/agentic/tools/legacy_adapter.py
@@ -0,0 +1,175 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Adapter lifting a legacy ``ToolTransformBase`` into a
+:class:`~torchrl.envs.llm.agentic.Tool`.
+
+Existing user code -- ``PythonInterpreter``, ``BrowserTransform``,
+``MCPToolTransform``, ``SimpleToolTransform`` -- can drop into
+:class:`~torchrl.envs.llm.agentic.ToolCompose` without rewriting.
+
+Bridge semantics
+~~~~~~~~~~~~~~~~
+
+The legacy classes implement ``_process_batch_item(content, index)`` which
+takes the *raw assistant message string* and returns a list of result
+strings. The agentic dispatcher already parsed the response and called us
+with ``args``, so the adapter synthesises a single-call message in the
+shape the legacy class expects (XML by default, configurable), runs the
+legacy ``_process_batch_item``, and returns the joined output string.
+
+This means the legacy class re-parses the synthesised message internally;
+that's the cost of bridging two different protocols. For tools that don't
+need this round-trip (i.e. anything new), write a native
+:class:`~torchrl.envs.llm.agentic.Tool` instead.
+"""
+from __future__ import annotations
+
+import asyncio
+import json
+from collections.abc import Mapping
+from typing import Any, Callable
+
+from ..protocols import TextPart, ToolContext, ToolError, ToolResult
+
+
+def _default_render(name: str, args: Mapping[str, Any]) -> str:
+ body = json.dumps(dict(args), ensure_ascii=False)
+ return f'{body}'
+
+
def as_tool(
    transform: Any,
    *,
    name: str,
    input_schema: Mapping[str, Any] | None = None,
    description: str = "",
    wants_state: bool = False,
    render_call: Callable[[str, Mapping[str, Any]], str] | None = None,
):
    """Wrap a legacy tool transform as a new-style
    :class:`~torchrl.envs.llm.agentic.Tool`.

    Args:
        transform: An instance of a legacy
            :class:`~torchrl.envs.llm.transforms.tools.ToolTransformBase`
            subclass (or anything with a compatible
            ``_process_batch_item(content: str, index: int) -> list[str] | None``).
        name: Tool name to expose to the LLM. Should match the name the
            legacy transform expects in the synthesised call message --
            i.e. the same tool name the model itself would write when
            invoking the tool.
        input_schema: JSON Schema dict for the LLM. If ``None``, a
            permissive ``{"type": "object"}`` is used.
        description: Tool description for the LLM.
        wants_state: Set to ``True`` to receive a filtered TensorDict view
            via ``ctx.state`` (mirrors the legacy ``pass_state_to_tools``
            knob).
        render_call: Custom function to format ``(name, args)`` into the
            string the legacy transform parses. Defaults to
            :func:`_default_render`, which serialises ``args`` as JSON.

    Returns:
        A :class:`Tool`-conforming object.

    Examples:
        >>> from torchrl.envs.llm.transforms import PythonInterpreter  # doctest: +SKIP
        >>> from torchrl.envs.llm.agentic.tools import as_tool
        >>> tool = as_tool(
        ...     PythonInterpreter(persistent=True),
        ...     name="python",
        ...     input_schema={"type": "object",
        ...                   "properties": {"code": {"type": "string"}},
        ...                   "required": ["code"]},
        ... )
    """
    return _LegacyToolAdapter(
        transform,
        name=name,
        input_schema=input_schema or {"type": "object"},
        description=description,
        wants_state=wants_state,
        render_call=render_call or _default_render,
    )
+
+
+class _LegacyToolAdapter:
+ """Tool that delegates to a legacy ``ToolTransformBase``-style object."""
+
+ output_schema = None
+
+ def __init__(
+ self,
+ transform: Any,
+ *,
+ name: str,
+ input_schema: Mapping[str, Any],
+ description: str,
+ wants_state: bool,
+ render_call: Callable[[str, Mapping[str, Any]], str],
+ ) -> None:
+ # Per-instance attrs deliberately (rather than ClassVar) -- the
+ # adapter is a factory and each invocation produces a distinct
+ # tool with its own name/schema.
+ self.name = name
+ self.input_schema = dict(input_schema)
+ self.description = description
+ self.wants_state = wants_state
+ self._transform = transform
+ self._render = render_call
+
+ async def setup(self) -> None:
+ # Legacy transforms don't have a uniform setup hook; honor it if
+ # the duck-typed object provides one.
+ hook = getattr(self._transform, "setup", None)
+ if callable(hook):
+ res = hook()
+ if asyncio.iscoroutine(res):
+ await res
+
+ async def teardown(self) -> None:
+ hook = getattr(self._transform, "teardown", None)
+ if callable(hook):
+ res = hook()
+ if asyncio.iscoroutine(res):
+ await res
+ # The legacy PythonInterpreter has its own close/shutdown patterns;
+ # if exposed, call them.
+ for method_name in ("close", "shutdown"):
+ method = getattr(self._transform, method_name, None)
+ if callable(method):
+ try:
+ res = method()
+ if asyncio.iscoroutine(res):
+ await res
+ except Exception: # pragma: no cover -- defensive
+ pass
+ break
+
+ async def run(
+ self, args: Mapping[str, Any], ctx: ToolContext
+ ) -> ToolResult:
+ rendered = self._render(self.name, args)
+ process = getattr(self._transform, "_process_batch_item", None)
+ if not callable(process):
+ raise ToolError(
+ f"legacy transform {type(self._transform).__name__!r} has "
+ "no _process_batch_item; not adaptable"
+ )
+ # Legacy tools are sync; offload so we don't block the event loop.
+ try:
+ results = await asyncio.to_thread(process, rendered, 0)
+ except Exception as e: # pragma: no cover -- depends on legacy impl
+ raise ToolError(str(e)) from e
+ if not results:
+ return ToolResult.from_text("", meta={"adapter": "legacy"})
+ text = "\n".join(str(r) for r in results)
+ return ToolResult(
+ parts=(TextPart(text=text),),
+ is_error=False,
+ meta={"adapter": "legacy", "count": len(results)},
+ )
+
+
+__all__ = ["as_tool"]