diff --git a/benchmarks/test_llm.py b/benchmarks/test_llm.py index aeebbd62491..c729a4dc040 100644 --- a/benchmarks/test_llm.py +++ b/benchmarks/test_llm.py @@ -99,3 +99,99 @@ def setup(): warmup_rounds=3, setup=setup, ) + + +# ----- Agentic ToolCompose dispatch benchmarks ----- + + +class _SleepyTool: + """Bench-only tool: simulates a network/I/O call via asyncio.sleep.""" + + description = "sleep" + input_schema = { + "type": "object", + "properties": {"ms": {"type": "integer"}}, + } + output_schema = None + wants_state = False + + def __init__(self, name: str) -> None: + self.name = name + + async def setup(self) -> None: + pass + + async def teardown(self) -> None: + pass + + async def run(self, args, ctx): + import asyncio as _asyncio + + from torchrl.envs.llm.agentic import ToolResult + + await _asyncio.sleep(args.get("ms", 100) / 1000) + return ToolResult.from_text("ok") + + +@pytest.mark.benchmark(group="agentic-dispatch") +@pytest.mark.parametrize("n_tools", [3, 8]) +def test_toolcompose_parallel_dispatch(benchmark, n_tools): + """Bench parallel ToolCompose dispatch. + + With ``n_tools`` async tools each sleeping 50ms, parallel dispatch + should bottom out near 50ms regardless of ``n_tools``; serial + dispatch would scale linearly. 
+ """ + from torchrl.envs import TransformedEnv + from torchrl.envs.llm import ChatEnv + from torchrl.envs.llm.agentic import ToolCompose + from torchrl.envs.llm.agentic.parsers import XMLToolCallParser + + set_list_to_stack(True).set() + base = ChatEnv(batch_size=(1,), input_mode="history") + tools = [_SleepyTool(f"t{i}") for i in range(n_tools)] + env = TransformedEnv( + base, ToolCompose(tools=tools, parser=XMLToolCallParser()) + ) + + fake = "".join( + f'{{"ms": 50}}' + for i in range(n_tools) + ) + + def go(): + obs = env.reset(TensorDict({"query": "go"}, batch_size=(1,))) + obs["history"].full = obs["history"].prompt.extend( + History(role="assistant", content=fake).view(1, 1), dim=-1 + ) + env.step(obs) + + benchmark(go) + + +@pytest.mark.benchmark(group="agentic-dispatch") +def test_toolcompose_single_call_baseline(benchmark): + """One-call baseline so the n=3 / n=8 numbers are interpretable.""" + from torchrl.envs import TransformedEnv + from torchrl.envs.llm import ChatEnv + from torchrl.envs.llm.agentic import ToolCompose + from torchrl.envs.llm.agentic.parsers import XMLToolCallParser + + set_list_to_stack(True).set() + base = ChatEnv(batch_size=(1,), input_mode="history") + env = TransformedEnv( + base, + ToolCompose( + tools=[_SleepyTool("t0")], parser=XMLToolCallParser() + ), + ) + fake = '{"ms": 50}' + + def go(): + obs = env.reset(TensorDict({"query": "go"}, batch_size=(1,))) + obs["history"].full = obs["history"].prompt.extend( + History(role="assistant", content=fake).view(1, 1), dim=-1 + ) + env.step(obs) + + benchmark(go) diff --git a/test/llm/test_agentic.py b/test/llm/test_agentic.py index 6816ac56bd9..666707a0393 100644 --- a/test/llm/test_agentic.py +++ b/test/llm/test_agentic.py @@ -15,15 +15,26 @@ import asyncio import json import sys +import time as _time import warnings +from typing import ClassVar import pytest -from tensordict import set_list_to_stack +from tensordict import set_list_to_stack, TensorDict +from torchrl.data.llm import 
History +from torchrl.envs import TransformedEnv +from torchrl.envs.llm import ChatEnv from torchrl.envs.llm.agentic import ( ParsedCall, + PythonTool, + RateLimiter, + ShellTool, + StopTool, Tool, ToolCallParser, + ToolCompose, + ToolContext, ToolResult, validate_args, ) @@ -43,6 +54,8 @@ ) from torchrl.envs.llm.agentic.sandbox.subprocess_bwrap import _has_bwrap from torchrl.envs.llm.agentic.sandbox.subprocess_seatbelt import _has_sandbox_exec +from torchrl.envs.llm.agentic.tools import as_tool +from torchrl.envs.llm.transforms import IncrementalTokenizer @pytest.fixture(scope="module", autouse=True) @@ -431,3 +444,365 @@ async def go(): assert r2.stdout.strip() == "42" _run(go()) + + +# ----- ToolCompose, builtins, legacy adapter ----- + + +class _Sleeper: + description: ClassVar[str] = "sleep N ms" + input_schema = { + "type": "object", + "properties": {"ms": {"type": "integer"}}, + } + output_schema = None + wants_state = False + + def __init__(self, name): + self.name = name + + async def setup(self): + pass + + async def teardown(self): + pass + + async def run(self, args, ctx): + await asyncio.sleep(args.get("ms", 100) / 1000) + return ToolResult.from_text(f"{self.name}-done") + + +class _Stateful: + name: ClassVar[str] = "stateful" + description: ClassVar[str] = "needs state" + input_schema = {"type": "object"} + output_schema = None + wants_state = True + received_state = None + + async def setup(self): + pass + + async def teardown(self): + pass + + async def run(self, args, ctx): + type(self).received_state = ctx.state + return ToolResult.from_text("ok") + + +class _Boom: + name: ClassVar[str] = "boom" + description = "always fails" + input_schema = {"type": "object"} + output_schema = None + wants_state = False + + async def setup(self): + pass + + async def teardown(self): + pass + + async def run(self, args, ctx): + raise RuntimeError("boom") + + +def _agentic_env(tools, parser=None): + parser = parser or XMLToolCallParser() + base = 
ChatEnv(batch_size=(1,), input_mode="history") + return TransformedEnv(base, ToolCompose(tools=tools, parser=parser)) + + +def _push_assistant(obs, response: str): + obs["history"].full = obs["history"].prompt.extend( + History(role="assistant", content=response).view(1, 1), dim=-1 + ) + + +class TestToolCompose: + def test_rejects_non_tool(self): + with pytest.raises(TypeError): + ToolCompose(tools=[object()], parser=XMLToolCallParser()) + + def test_rejects_duplicate_names(self): + with pytest.raises(ValueError): + ToolCompose( + tools=[_Sleeper("dup"), _Sleeper("dup")], + parser=XMLToolCallParser(), + ) + + def test_append_transform_blocked(self): + compose = ToolCompose( + tools=[StopTool()], parser=XMLToolCallParser() + ) + with pytest.raises(TypeError): + compose.append_transform(IncrementalTokenizer) + + def test_lookup_by_name(self): + compose = ToolCompose( + tools=[StopTool()], parser=XMLToolCallParser() + ) + assert "stop" in compose + assert compose["stop"].name == "stop" + + def test_parallel_dispatch_wall_time(self): + env = _agentic_env([_Sleeper("a"), _Sleeper("b"), _Sleeper("c")]) + obs = env.reset(TensorDict({"query": "go"}, batch_size=(1,))) + _push_assistant( + obs, + '{"ms": 500}' + '{"ms": 500}' + '{"ms": 500}', + ) + t0 = _time.monotonic() + nxt = env.step(obs) + elapsed = _time.monotonic() - t0 + # Three 500ms tools must run concurrently: total < 0.8s. 
+ assert elapsed < 0.9, ( + f"parallel dispatch took {elapsed:.2f}s; expected < 0.8s" + ) + assert bool(nxt.get(("next", "agentic", "any_tool_calls")).item()) + assert not bool(nxt.get(("next", "agentic", "stop_requested")).item()) + + def test_stop_tool_terminates(self): + env = _agentic_env([StopTool()]) + obs = env.reset(TensorDict({"query": "stop"}, batch_size=(1,))) + _push_assistant(obs, '{"reason":"done"}') + nxt = env.step(obs) + assert bool(nxt.get(("next", "agentic", "stop_requested")).item()) + + def test_no_tool_calls_passthrough(self): + env = _agentic_env([StopTool()]) + obs = env.reset(TensorDict({"query": "nothing"}, batch_size=(1,))) + _push_assistant(obs, "I have nothing to call.") + nxt = env.step(obs) + assert not bool(nxt.get(("next", "agentic", "any_tool_calls")).item()) + + def test_unknown_tool_reports_error(self): + env = _agentic_env([StopTool()]) + obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,))) + _push_assistant(obs, '{}') + nxt = env.step(obs) + assert bool(nxt.get(("next", "agentic", "any_error")).item()) + + def test_failure_isolation(self): + env = _agentic_env([_Sleeper("a"), _Boom()]) + obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,))) + _push_assistant( + obs, + '{"ms": 5}' + '{}', + ) + nxt = env.step(obs) + # both calls fired; one failed, one succeeded. + assert bool(nxt.get(("next", "agentic", "any_error")).item()) + assert bool(nxt.get(("next", "agentic", "any_tool_calls")).item()) + # The history should have both tool messages appended. + prompt = nxt[("next", "history")].prompt + # Assistant message + 2 tool messages = at least 3 entries beyond + # the original prompt. 
+ assert len(prompt[0]) >= 3 + + def test_stable_call_id_round_trip(self): + captured: list[str] = [] + + class _Recorder: + name: ClassVar[str] = "rec" + description = "records call_id" + input_schema = {"type": "object"} + output_schema = None + wants_state = False + + async def setup(self): + pass + + async def teardown(self): + pass + + async def run(self, args, ctx): + captured.append(ctx.call_id) + return ToolResult.from_text(f"id={ctx.call_id}") + + env = _agentic_env([_Recorder()]) + obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,))) + _push_assistant(obs, '{}') + nxt = env.step(obs) + assert captured == ["my-id"] + # The rendered tool message must reference the same call_id. + prompt = nxt[("next", "history")].prompt + last_msg = prompt[0][-1] + assert "my-id" in last_msg.content + + def test_pass_state_to_tools(self): + _Stateful.received_state = None + base = ChatEnv(batch_size=(1,), input_mode="history") + env = TransformedEnv( + base, + ToolCompose( + tools=[_Stateful()], + parser=XMLToolCallParser(), + pass_state_to_tools=True, + ), + ) + obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,))) + _push_assistant(obs, '{}') + env.step(obs) + assert _Stateful.received_state is not None + + def test_pass_state_off_means_no_state(self): + _Stateful.received_state = None + env = _agentic_env([_Stateful()]) # pass_state_to_tools defaults False + obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,))) + _push_assistant(obs, '{}') + env.step(obs) + assert _Stateful.received_state is None + + def test_rate_limit_serializes_concurrent_calls(self): + # max_concurrent=1 forces 3 calls of 200ms each to take >= 600ms. 
+ slow = _Sleeper("slow") + compose = ToolCompose( + tools=[slow], + parser=XMLToolCallParser(), + rate_limits={"slow": RateLimiter(max_concurrent=1)}, + ) + base = ChatEnv(batch_size=(1,), input_mode="history") + env = TransformedEnv(base, compose) + obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,))) + _push_assistant( + obs, + '{"ms": 200}' + '{"ms": 200}' + '{"ms": 200}', + ) + t0 = _time.monotonic() + env.step(obs) + elapsed = _time.monotonic() - t0 + assert elapsed >= 0.55, ( + f"rate-limited dispatch should serialize: got {elapsed:.2f}s" + ) + + def test_argument_validation(self): + class _NeedsCode: + name: ClassVar[str] = "needs" + description = "" + input_schema = { + "type": "object", + "properties": {"code": {"type": "string"}}, + "required": ["code"], + } + output_schema = None + wants_state = False + + async def setup(self): + pass + + async def teardown(self): + pass + + async def run(self, args, ctx): # pragma: no cover - never reached + return ToolResult.from_text("hit") + + env = _agentic_env([_NeedsCode()]) + obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,))) + _push_assistant(obs, '{}') # missing required + nxt = env.step(obs) + assert bool(nxt.get(("next", "agentic", "any_error")).item()) + + def test_nested_loop_safety(self): + # When the caller already owns an event loop, ToolCompose._step + # must still complete (offload to a worker thread). 
+ async def go(): + env = _agentic_env([StopTool()]) + obs = env.reset(TensorDict({"query": "?"}, batch_size=(1,))) + _push_assistant(obs, '{}') + return env.step(obs) + + nxt = _run(go()) + assert bool(nxt.get(("next", "agentic", "stop_requested")).item()) + + +class TestPythonTool: + def test_state_persists_across_calls(self): + async def go(): + async with UnsafeSubprocessSandbox( + ResourceLimits(wall_seconds=10) + ) as s: + tool = PythonTool(repl=SubprocessRepl(s)) + await tool.setup() + ctx = ToolContext(call_id="c1") + r1 = await tool.run({"code": "x = 41"}, ctx) + assert not r1.is_error + r2 = await tool.run({"code": "print(x + 1)"}, ctx) + assert not r2.is_error + assert "42" in r2.text + await tool.teardown() + + _run(go()) + + def test_error_marked_is_error(self): + async def go(): + async with UnsafeSubprocessSandbox( + ResourceLimits(wall_seconds=10) + ) as s: + tool = PythonTool(repl=SubprocessRepl(s)) + await tool.setup() + r = await tool.run({"code": "1/0"}, ToolContext(call_id="c")) + assert r.is_error + assert "ZeroDivisionError" in r.text + await tool.teardown() + + _run(go()) + + +class TestShellTool: + def test_runs_argv(self): + async def go(): + async with UnsafeSubprocessSandbox( + ResourceLimits(wall_seconds=5) + ) as s: + tool = ShellTool(s) + await tool.setup() + r = await tool.run( + {"argv": ["/bin/echo", "hi"]}, ToolContext(call_id="c") + ) + assert not r.is_error + assert "hi" in r.text + # Don't tear down s twice -- ShellTool teardown closes it. + + _run(go()) + + +class TestLegacyAdapter: + def test_lifts_legacy_transform(self): + # Use a tiny duck-typed legacy class instead of pulling in the + # full PythonInterpreter, which has its own subprocess pool. + class _LegacyAdder: + tool_role = "tool" + + def _process_batch_item(self, content, index): + # Echo the captured XML so the assertion can find it. 
+ if "{"a": 1, "b": 2}') + nxt = env.step(obs) + prompt = nxt[("next", "history")].prompt + # The last appended message should be the tool result containing + # the legacy output. + assert "legacy got" in prompt[0][-1].content diff --git a/torchrl/envs/llm/agentic/__init__.py b/torchrl/envs/llm/agentic/__init__.py index 7bdea7108dc..a18f6a6f698 100644 --- a/torchrl/envs/llm/agentic/__init__.py +++ b/torchrl/envs/llm/agentic/__init__.py @@ -16,6 +16,7 @@ """ from __future__ import annotations +from .compose import DispatchResult, ToolCompose from .protocols import ( FileRefPart, ImagePart, @@ -30,21 +31,39 @@ ToolResult, ToolResultPart, ) +from .rate_limit import RateLimiter from .schema import json_schema_from_pydantic, validate_args +from .tools import ( + FileReadTool, + PythonTool, + ShellTool, + StopSignal, + StopTool, + as_tool, +) __all__ = [ + "DispatchResult", + "FileReadTool", "FileRefPart", "ImagePart", "JsonPart", "ParseResult", "ParsedCall", + "PythonTool", + "RateLimiter", + "ShellTool", + "StopSignal", + "StopTool", "TextPart", "Tool", "ToolCallParser", + "ToolCompose", "ToolContext", "ToolError", "ToolResult", "ToolResultPart", + "as_tool", "json_schema_from_pydantic", "validate_args", ] diff --git a/torchrl/envs/llm/agentic/compose.py b/torchrl/envs/llm/agentic/compose.py new file mode 100644 index 00000000000..5b34c8a2b92 --- /dev/null +++ b/torchrl/envs/llm/agentic/compose.py @@ -0,0 +1,519 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""ToolCompose: parallel async tool dispatch over a ChatEnv history. + +A :class:`ToolCompose` is a :class:`~torchrl.envs.transforms.Compose` +subclass that: + +1. owns a :class:`~torchrl.envs.llm.agentic.ToolCallParser`, +2. holds a fixed set of :class:`~torchrl.envs.llm.agentic.Tool` instances + (raises :class:`TypeError` on non-Tool insert), +3. 
on each step parses the latest assistant message *once*, dispatches + matched tools concurrently via :func:`asyncio.gather`, renders each + result through the parser, and extends the + :class:`~torchrl.data.llm.History` with the resulting tool messages, +4. surfaces ``("agentic", "any_tool_calls")`` and + ``("agentic", "stop_requested")`` keys in the step output for the env + to use as termination signals. + +The ChatEnv is unchanged; ``ToolCompose`` lives entirely in transform +space. +""" +from __future__ import annotations + +import asyncio +import json +import threading +import uuid +from collections.abc import Iterable, Mapping +from concurrent.futures import Future +from dataclasses import dataclass, field +from typing import Any + +import torch +from tensordict import lazy_stack, TensorDictBase +from torchrl._utils import logger as torchrl_logger +from torchrl.data.llm import History +from torchrl.envs.transforms import Compose, Transform + +from .protocols import ( + ParsedCall, + ParseResult, + Tool, + ToolCallParser, + ToolContext, + ToolError, + ToolResult, +) +from .rate_limit import RateLimiter +from .schema import validate_args +from .tools.builtin import StopSignal + + +@dataclass +class DispatchResult: + """Aggregate outcome of one :meth:`ToolCompose._dispatch_one` call. + + Attributes: + cleaned_text: Assistant text with tool-call syntax stripped. + calls: Parsed calls in emission order. + results: Tool results, aligned with ``calls``. + any_error: ``True`` if any tool failed (including unknown tool). + stop_requested: ``True`` if any tool raised :class:`StopSignal`. + """ + + cleaned_text: str = "" + calls: tuple[ParsedCall, ...] = () + results: tuple[ToolResult, ...] = () + any_error: bool = False + stop_requested: bool = False + + +class _ToolTransformShim(Transform): + """Thin :class:`Transform` wrapper around a :class:`Tool`. 
+ + Stored inside :class:`ToolCompose.transforms` so the parent ``Compose`` + machinery (module hierarchy, device moves, training-mode propagation) + keeps working. The shim has no ``_step`` of its own -- ``ToolCompose`` + drives dispatch directly and consults this shim only as a holder. + """ + + def __init__(self, tool: Tool) -> None: + super().__init__() + # Do NOT register the tool as a submodule -- tools are not nn.Module. + # Storing in __dict__ keeps it out of state_dict. + object.__setattr__(self, "tool", tool) + + @property + def name(self) -> str: + return self.tool.name + + def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: # noqa: D401 + # ToolCompose drives dispatch directly; the shim is a pass-through. + return next_tensordict + + +class ToolCompose(Compose): + """A :class:`Compose` of :class:`~torchrl.envs.llm.agentic.Tool` objects + with parallel async dispatch. + + Args: + tools: The set of tools the LLM may call. Each must conform to the + :class:`~torchrl.envs.llm.agentic.Tool` protocol (``name``, + ``input_schema``, async ``run``). Inserting a non-Tool (e.g. a + plain :class:`Transform`) raises :class:`TypeError`. + parser: The :class:`~torchrl.envs.llm.agentic.ToolCallParser` used + to extract calls from each assistant message and render + results back as tool messages. + rate_limits: Optional per-tool :class:`RateLimiter` map (keyed by + tool name). Tools without an entry are unthrottled. + per_call_timeout: Default per-call timeout in seconds. ``None`` + means rely on tool/repl timeouts only. + pass_state_to_tools: If ``True`` and a tool has + ``wants_state=True``, ``ctx.state`` is populated with a + filtered read-only view of the env tensordict (mirrors the + legacy ``ExecuteToolsInOrder`` knob). + tool_role: Role string used when injecting tool messages into + :class:`~torchrl.data.llm.History` (default ``"tool"``). + validate_inputs: If ``True`` (default), validate ``args`` against + ``tool.input_schema`` before dispatch. 
Schema mismatches are + reported as :class:`ToolResult` with ``is_error=True``. + + Example: + >>> from torchrl.envs import TransformedEnv + >>> from torchrl.envs.llm import ChatEnv # doctest: +SKIP + >>> from torchrl.envs.llm.agentic import ToolCompose # doctest: +SKIP + >>> from torchrl.envs.llm.agentic.parsers import XMLToolCallParser + >>> from torchrl.envs.llm.agentic.tools import StopTool + >>> compose = ToolCompose( + ... tools=[StopTool()], + ... parser=XMLToolCallParser(), + ... ) + """ + + def __init__( + self, + *, + tools: Iterable[Tool], + parser: ToolCallParser, + rate_limits: Mapping[str, RateLimiter] | None = None, + per_call_timeout: float | None = None, + pass_state_to_tools: bool = False, + tool_role: str = "tool", + validate_inputs: bool = True, + ) -> None: + tool_list = list(tools) + for t in tool_list: + if not _is_tool(t): + raise TypeError( + f"ToolCompose accepts Tool objects only; got " + f"{type(t).__name__!r}. Wrap a legacy transform with " + "torchrl.envs.llm.agentic.tools.as_tool(...)." 
+ ) + if not isinstance(parser, ToolCallParser): + raise TypeError( + "parser must implement ToolCallParser; got " + f"{type(parser).__name__!r}" + ) + names = [t.name for t in tool_list] + if len(names) != len(set(names)): + seen: set[str] = set() + dups = [n for n in names if n in seen or seen.add(n)] # type: ignore[func-returns-value] + raise ValueError(f"duplicate tool names: {dups!r}") + shims = [_ToolTransformShim(t) for t in tool_list] + super().__init__(*shims) + self._tool_list: list[Tool] = tool_list + self._tools_by_name: dict[str, Tool] = {t.name: t for t in tool_list} + self.parser = parser + self._rate_limits: dict[str, RateLimiter] = dict(rate_limits or {}) + self._per_call_timeout = per_call_timeout + self._pass_state = pass_state_to_tools + self._tool_role = tool_role + self._validate_inputs = validate_inputs + self._setup_done = False + + # ----- introspection helpers ----- + + @property + def tools(self) -> tuple[Tool, ...]: + return tuple(self._tool_list) + + def __getitem__(self, key: str | int): # type: ignore[override] + if isinstance(key, str): + return self._tools_by_name[key] + return super().__getitem__(key) + + def __contains__(self, key: object) -> bool: # type: ignore[override] + if isinstance(key, str): + return key in self._tools_by_name + return super().__contains__(key) + + # ----- enforce Tool-only insertion ----- + + def append_transform(self, transform): # type: ignore[override] + # ``Compose.append_transform`` accepts arbitrary Transforms. We + # constrain to Tools to keep the dispatch invariants intact. + raise TypeError( + "ToolCompose does not accept arbitrary transforms. Use " + "append_tool(tool) instead." 
+ ) + + def append_tool(self, tool: Tool) -> None: + """Append a :class:`Tool` to the dispatch set.""" + if not _is_tool(tool): + raise TypeError( + f"append_tool requires a Tool; got {type(tool).__name__!r}" + ) + if tool.name in self._tools_by_name: + raise ValueError(f"duplicate tool name: {tool.name!r}") + shim = _ToolTransformShim(tool) + self.transforms.append(shim) + shim.set_container(self) + self._tool_list.append(tool) + self._tools_by_name[tool.name] = tool + + # ----- lifecycle ----- + + async def _setup_tools(self) -> None: + if self._setup_done: + return + for tool in self._tool_list: + try: + await tool.setup() + except Exception: # pragma: no cover -- per-tool setup is best-effort + torchrl_logger.exception( + "tool %r setup raised; continuing", tool.name + ) + self._setup_done = True + + async def _teardown_tools(self) -> None: + for tool in self._tool_list: + try: + await tool.teardown() + except Exception: # pragma: no cover + torchrl_logger.exception( + "tool %r teardown raised; continuing", tool.name + ) + self._setup_done = False + + def close(self) -> None: # type: ignore[override] + super().close() + try: + _run_async(self._teardown_tools()) + except Exception: # pragma: no cover + torchrl_logger.exception("teardown_tools raised; continuing") + + # ----- the step path ----- + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + if next_tensordict.batch_dims > 1: + with next_tensordict.view(-1) as flat: + flat = self._step(tensordict, flat) + return next_tensordict + history = self._extract_history(next_tensordict) + last = history[..., -1] + contents = last.content + if isinstance(contents, str): + contents = [contents] + # asyncio.gather across batch items. 
+ dispatch_results: list[DispatchResult] = _run_async( + self._dispatch_batch(list(contents), next_tensordict) + ) + return self._inject_results(history, dispatch_results, next_tensordict) + + def _extract_history(self, td: TensorDictBase) -> History: + parent = self.parent + if parent is None: + raise RuntimeError( + "ToolCompose must be used inside a TransformedEnv" + ) + base_env = parent.base_env + if getattr(base_env, "input_mode", None) != "history": + raise RuntimeError( + "ToolCompose requires the underlying ChatEnv to use " + "input_mode='history' (got " + f"{getattr(base_env, 'input_mode', None)!r})" + ) + return td["history"].prompt + + async def _dispatch_batch( + self, contents: list[str], td: TensorDictBase + ) -> list[DispatchResult]: + await self._setup_tools() + # Each batch item runs concurrently with every other; within a + # batch item, calls also run concurrently. + return list( + await asyncio.gather( + *( + self._dispatch_one(content, td, batch_index=i) + for i, content in enumerate(contents) + ) + ) + ) + + async def _dispatch_one( + self, content: str, td: TensorDictBase, *, batch_index: int + ) -> DispatchResult: + parsed: ParseResult = self.parser.parse(content) + if not parsed.calls: + return DispatchResult(cleaned_text=parsed.text) + # Parallel run all calls in this batch item. 
+ coros = [ + self._run_one(call, td, batch_index=batch_index) + for call in parsed.calls + ] + outcomes = await asyncio.gather(*coros, return_exceptions=False) + results: list[ToolResult] = [] + any_error = False + stop = False + for r in outcomes: + if isinstance(r, _StopMarker): + results.append(r.result) + stop = True + else: + results.append(r) + if r.is_error: + any_error = True + return DispatchResult( + cleaned_text=parsed.text, + calls=parsed.calls, + results=tuple(results), + any_error=any_error, + stop_requested=stop, + ) + + async def _run_one( + self, call: ParsedCall, td: TensorDictBase, *, batch_index: int + ): + tool = self._tools_by_name.get(call.tool) + if tool is None: + return ToolResult.from_text( + f"unknown tool: {call.tool!r}", + is_error=True, + meta={"call_id": call.call_id}, + ) + if self._validate_inputs: + try: + validate_args(call.args, tool.input_schema) + except Exception as e: + return ToolResult.from_text( + f"argument validation failed for {tool.name!r}: {e}", + is_error=True, + meta={"call_id": call.call_id}, + ) + ctx = ToolContext( + call_id=call.call_id, + tag=call.tag, + state=self._filter_state(td, batch_index) if self._pass_state and getattr( + tool, "wants_state", False + ) else None, + sandbox=getattr(tool, "sandbox", None), + repl=getattr(tool, "repl", None), + compose=self, + ) + limiter = self._rate_limits.get(tool.name) + timeout = self._per_call_timeout + try: + if limiter is not None: + async with limiter.slot(): + coro = tool.run(call.args, ctx) + if timeout is None: + return await coro + return await asyncio.wait_for(coro, timeout=timeout) + coro = tool.run(call.args, ctx) + if timeout is None: + return await coro + return await asyncio.wait_for(coro, timeout=timeout) + except StopSignal as s: + return _StopMarker( + ToolResult.from_text( + f"[stop] {s}", meta={"call_id": call.call_id, "stop": True} + ) + ) + except ToolError as e: + return ToolResult.from_text( + str(e), is_error=True, meta={"call_id": 
call.call_id} + ) + except asyncio.TimeoutError: + return ToolResult.from_text( + f"tool {tool.name!r} timed out after {timeout}s", + is_error=True, + meta={"call_id": call.call_id, "timed_out": True}, + ) + except Exception as e: # noqa: BLE001 -- failure isolation + torchrl_logger.exception( + "tool %r raised; reporting as error", tool.name + ) + return ToolResult.from_text( + f"{type(e).__name__}: {e}", + is_error=True, + meta={"call_id": call.call_id}, + ) + + def _filter_state( + self, td: TensorDictBase, batch_index: int + ) -> TensorDictBase | None: + try: + view = td[batch_index] if td.batch_dims else td + except Exception: # pragma: no cover + return None + try: + return view.exclude("history") + except Exception: # pragma: no cover + return view + + # ----- result injection ----- + + def _inject_results( + self, + history: History, + dispatches: list[DispatchResult], + td: TensorDictBase, + ) -> TensorDictBase: + # Build per-batch-item History extension lists and the agentic + # signals. 
+ any_calls = [bool(d.calls) for d in dispatches] + any_error = [d.any_error for d in dispatches] + any_stop = [d.stop_requested for d in dispatches] + + per_item_messages: list[list[History]] = [] + for d in dispatches: + messages: list[History] = [] + for call, result in zip(d.calls, d.results): + rendered = self.parser.render_result(call.call_id, result) + content = rendered.get("content", "") + if isinstance(content, list): + content = json.dumps(content, ensure_ascii=False) + messages.append( + History(role=self._tool_role, content=str(content)) + ) + per_item_messages.append(messages) + + max_len = max((len(m) for m in per_item_messages), default=0) + if max_len > 0: + padded: list[list[History]] = [] + for messages in per_item_messages: + if len(messages) == max_len: + padded.append(messages) + else: + padded.append( + messages + + [History(role="", content="")] + * (max_len - len(messages)) + ) + stacked = lazy_stack([lazy_stack(m) for m in padded]) + history.extend(stacked, dim=-1) + td["history"].prompt = history + + device = td.device + td.set( + ("agentic", "any_tool_calls"), + torch.tensor(any_calls, dtype=torch.bool, device=device), + ) + td.set( + ("agentic", "any_error"), + torch.tensor(any_error, dtype=torch.bool, device=device), + ) + td.set( + ("agentic", "stop_requested"), + torch.tensor(any_stop, dtype=torch.bool, device=device), + ) + return td + + +@dataclass +class _StopMarker: + """Internal carrier flagging a :class:`StopSignal` from a tool.""" + + result: ToolResult + + +def _is_tool(obj: Any) -> bool: + """Duck-typed Tool conformance check. + + ``Tool`` is a runtime-checkable Protocol but Protocol checks treat + every required attribute as needing to be present at the *instance* + level; class-level ``ClassVar``s satisfy this in CPython but the check + is unreliable across versions when adapters set per-instance ``name``. + Reuse a simpler attribute-presence test. 
+ """ + required = ("name", "input_schema", "run", "setup", "teardown") + return all(hasattr(obj, a) for a in required) and callable(obj.run) + + +# ----- async runner: nested-loop safe ----- + +def _run_async(coro): + """Run ``coro`` to completion regardless of whether the caller is + inside an event loop. + + - No running loop: :func:`asyncio.run`. + - Running loop: dispatch on a worker thread that owns its own loop + and join. (Necessary because Compose._step is sync and may be + called from inside a Jupyter-style outer loop.) + """ + try: + asyncio.get_running_loop() + running = True + except RuntimeError: + running = False + if not running: + return asyncio.run(coro) + fut: Future = Future() + + def _target(): + try: + fut.set_result(asyncio.run(coro)) + except BaseException as e: # noqa: BLE001 + fut.set_exception(e) + + t = threading.Thread(target=_target, daemon=True) + t.start() + return fut.result() + + +__all__ = ["DispatchResult", "ToolCompose"] diff --git a/torchrl/envs/llm/agentic/rate_limit.py b/torchrl/envs/llm/agentic/rate_limit.py new file mode 100644 index 00000000000..f463be9b1e3 --- /dev/null +++ b/torchrl/envs/llm/agentic/rate_limit.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Asyncio rate limiter -- semaphore + token bucket. + +A :class:`RateLimiter` caps concurrent in-flight calls and (optionally) the +sustained call rate. Used by :class:`~torchrl.envs.llm.agentic.ToolCompose` +to throttle individual tools (e.g. a search API at 5 QPS) without blocking +the rest of the dispatch. +""" +from __future__ import annotations + +import asyncio +import time +from contextlib import asynccontextmanager + + +class RateLimiter: + """Combined semaphore + token-bucket throttle. + + Args: + max_concurrent: Cap on simultaneously in-flight calls. ``None`` for + unlimited. 
+ rate_per_second: Sustained refill rate. ``None`` disables the + token bucket; only the semaphore is enforced. + burst: Token-bucket capacity. Defaults to ``rate_per_second``. + + Examples: + >>> import asyncio + >>> async def go(): + ... limiter = RateLimiter(max_concurrent=2) + ... async with limiter.slot(): + ... pass + >>> asyncio.run(go()) + """ + + def __init__( + self, + *, + max_concurrent: int | None = None, + rate_per_second: float | None = None, + burst: float | None = None, + ) -> None: + self._sem: asyncio.Semaphore | None = ( + asyncio.Semaphore(max_concurrent) if max_concurrent else None + ) + self._rate = rate_per_second + self._capacity = burst if burst is not None else (rate_per_second or 0.0) + self._tokens = self._capacity + self._last = time.monotonic() + self._lock = asyncio.Lock() + + async def _consume(self) -> None: + if not self._rate: + return + async with self._lock: + now = time.monotonic() + self._tokens = min( + self._capacity, self._tokens + (now - self._last) * self._rate + ) + self._last = now + if self._tokens >= 1.0: + self._tokens -= 1.0 + return + wait = (1.0 - self._tokens) / self._rate + await asyncio.sleep(wait) + async with self._lock: + self._tokens = max(0.0, self._tokens - 1.0) + + @asynccontextmanager + async def slot(self): + """Acquire one slot. Blocks until both the semaphore and token + bucket allow.""" + await self._consume() + if self._sem is None: + yield + return + async with self._sem: + yield + + +__all__ = ["RateLimiter"] diff --git a/torchrl/envs/llm/agentic/tools/__init__.py b/torchrl/envs/llm/agentic/tools/__init__.py new file mode 100644 index 00000000000..2795b09c4fa --- /dev/null +++ b/torchrl/envs/llm/agentic/tools/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Built-in tools and the legacy-transform adapter. 
"""Built-in tools and the legacy-transform adapter.

- :class:`PythonTool`, :class:`ShellTool`, :class:`FileReadTool`,
  :class:`StopTool` -- ready-to-use tools wired to the agentic
  Sandbox/Repl primitives.
- :func:`as_tool` -- adapter lifting any legacy
  :class:`~torchrl.envs.llm.transforms.tools.ToolTransformBase` subclass
  into a :class:`~torchrl.envs.llm.agentic.Tool` so existing user code
  keeps working inside :class:`~torchrl.envs.llm.agentic.ToolCompose`.
"""
from __future__ import annotations

from .builtin import FileReadTool, PythonTool, ShellTool, StopSignal, StopTool
from .legacy_adapter import as_tool

__all__ = [
    "FileReadTool",
    "PythonTool",
    "ShellTool",
    "StopSignal",
    "StopTool",
    "as_tool",
]
class StopSignal(Exception):
    """Raised by :class:`StopTool` to end the agent loop.

    :class:`~torchrl.envs.llm.agentic.ToolCompose` catches this and sets
    the corresponding episode-end flag in the step output.
    """


class PythonTool:
    """Execute Python code in a stateful :class:`Repl`.

    Args:
        repl: The REPL backend.  Must be opened (or used as a context
            manager) before dispatch; ``ToolCompose`` opens/closes it on
            env reset/close when it owns the repl.
        timeout: Default per-call timeout in seconds.  Per-call ``ctx``
            may override.
        output_max_chars: Cap on the returned text; longer output is
            truncated with an explicit marker.

    Examples:
        >>> from torchrl.envs.llm.agentic.sandbox import UnsafeSubprocessSandbox
        >>> from torchrl.envs.llm.agentic.repl import SubprocessRepl
        >>> tool = PythonTool(repl=SubprocessRepl(UnsafeSubprocessSandbox()))
    """

    name: ClassVar[str] = "python"
    description: ClassVar[str] = "Execute Python code; state persists across calls."
    input_schema: ClassVar[Mapping[str, Any]] = {
        "type": "object",
        "properties": {"code": {"type": "string"}},
        "required": ["code"],
    }
    output_schema: ClassVar[Mapping[str, Any] | None] = None
    wants_state: ClassVar[bool] = False

    def __init__(
        self,
        repl: Repl,
        *,
        timeout: float | None = 30.0,
        output_max_chars: int = 8192,
    ) -> None:
        self.repl = repl
        self.timeout = timeout
        self.output_max_chars = output_max_chars

    async def setup(self) -> None:
        await self.repl.open()

    async def teardown(self) -> None:
        await self.repl.close()

    async def run(self, args: Mapping[str, Any], ctx: ToolContext) -> ToolResult:
        code = args.get("code", "")
        if not isinstance(code, str):
            raise ToolError("'code' must be a string")
        result = await self.repl.execute(code, timeout=self.timeout)
        text = result.text
        # Clip the payload before shipping it back to the model.
        truncated = len(text) > self.output_max_chars
        if truncated:
            text = text[: self.output_max_chars] + "\n... [truncated]"
        return ToolResult(
            parts=(TextPart(text=text),),
            is_error=result.error is not None or result.timed_out,
            meta={
                "execution_count": result.execution_count,
                "timed_out": result.timed_out,
                "truncated": truncated,
            },
        )
[truncated]" + truncated = True + return ToolResult( + parts=(TextPart(text=text),), + is_error=result.error is not None or result.timed_out, + meta={ + "execution_count": result.execution_count, + "timed_out": result.timed_out, + "truncated": truncated, + }, + ) + + +class ShellTool: + """Execute a shell command inside a :class:`Sandbox`. + + Accepts either ``argv: list[str]`` or ``command: str``. ``command`` + is split with :func:`shlex.split` -- callers needing pipes should + use ``argv=["sh", "-c", "..."]`` explicitly. + """ + + name: ClassVar[str] = "shell" + description: ClassVar[str] = "Execute a shell command in a sandbox." + input_schema: ClassVar[Mapping[str, Any]] = { + "type": "object", + "properties": { + "command": {"type": "string"}, + "argv": {"type": "array", "items": {"type": "string"}}, + "cwd": {"type": "string"}, + }, + } + output_schema: ClassVar[Mapping[str, Any] | None] = None + wants_state: ClassVar[bool] = False + + def __init__(self, sandbox: Sandbox) -> None: + self.sandbox = sandbox + + async def setup(self) -> None: + await self.sandbox.open() + + async def teardown(self) -> None: + await self.sandbox.close() + + async def run( + self, args: Mapping[str, Any], ctx: ToolContext + ) -> ToolResult: + argv = args.get("argv") + command = args.get("command") + if argv is None and command is None: + raise ToolError("ShellTool requires 'argv' or 'command'") + if argv is None: + argv = shlex.split(str(command)) + result = await self.sandbox.run(list(argv), cwd=args.get("cwd")) + body_lines: list[str] = [] + if result.stdout: + body_lines.append(result.stdout) + if result.stderr: + body_lines.append(f"[stderr]\n{result.stderr}") + body_lines.append(f"[exit {result.exit_code}]") + return ToolResult( + parts=(TextPart(text="\n".join(body_lines).rstrip()),), + is_error=result.exit_code != 0 or result.timed_out, + meta={ + "exit_code": result.exit_code, + "timed_out": result.timed_out, + "wall_seconds": result.wall_seconds, + }, + ) + + +class 
class FileReadTool:
    """Read a file from inside a :class:`Sandbox`."""

    name: ClassVar[str] = "file_read"
    description: ClassVar[str] = "Read a file from the sandbox filesystem."
    input_schema: ClassVar[Mapping[str, Any]] = {
        "type": "object",
        "properties": {
            "path": {"type": "string"},
            "max_bytes": {"type": "integer"},
        },
        "required": ["path"],
    }
    output_schema: ClassVar[Mapping[str, Any] | None] = None
    wants_state: ClassVar[bool] = False

    def __init__(self, sandbox: Sandbox) -> None:
        self.sandbox = sandbox

    async def setup(self) -> None:
        await self.sandbox.open()

    async def teardown(self) -> None:
        await self.sandbox.close()

    async def run(self, args: Mapping[str, Any], ctx: ToolContext) -> ToolResult:
        path = args["path"]
        try:
            raw = await self.sandbox.read_file(
                path, max_bytes=args.get("max_bytes")
            )
        except FileNotFoundError as e:
            # Surface a tool-level error instead of leaking the OSError.
            raise ToolError(f"file not found: {path}") from e
        return ToolResult.from_text(
            raw.decode("utf-8", errors="replace"),
            meta={"bytes": len(raw)},
        )
class StopTool:
    """Zero-arg tool that ends the agent episode.

    :meth:`run` always raises :class:`StopSignal`; the dispatcher
    catches it and flags episode termination in the step output so the
    env can end the rollout.
    """

    name: ClassVar[str] = "stop"
    description: ClassVar[str] = "Signal that the agent has finished its task."
    input_schema: ClassVar[Mapping[str, Any]] = {
        "type": "object",
        "properties": {"reason": {"type": "string"}},
    }
    output_schema: ClassVar[Mapping[str, Any] | None] = None
    wants_state: ClassVar[bool] = False

    async def setup(self) -> None:  # no resources to acquire
        pass

    async def teardown(self) -> None:  # nothing to release
        pass

    async def run(self, args: Mapping[str, Any], ctx: ToolContext) -> ToolResult:
        raise StopSignal(str(args.get("reason", "done")))


__all__ = [
    "FileReadTool",
    "PythonTool",
    "ShellTool",
    "StopSignal",
    "StopTool",
]
def _default_render(name: str, args: Mapping[str, Any]) -> str:
    # NOTE(review): the call-envelope markup around the JSON body (e.g.
    # XML tags carrying ``name``) appears to have been stripped in this
    # view -- confirm the rendered string matches what the legacy
    # transforms' parsers expect to re-parse.
    return json.dumps(dict(args), ensure_ascii=False)


def as_tool(
    transform: Any,
    *,
    name: str,
    input_schema: Mapping[str, Any] | None = None,
    description: str = "",
    wants_state: bool = False,
    render_call: Callable[[str, Mapping[str, Any]], str] | None = None,
):
    """Wrap a legacy tool transform as a new-style
    :class:`~torchrl.envs.llm.agentic.Tool`.

    Args:
        transform: An instance of a legacy
            :class:`~torchrl.envs.llm.transforms.tools.ToolTransformBase`
            subclass (or anything with a compatible
            ``_process_batch_item(content: str, index: int) -> list[str] | None``).
        name: Tool name to expose to the LLM.  Should match the name the
            legacy transform expects in the synthesised call envelope.
        input_schema: JSON Schema dict for the LLM.  If ``None`` (or
            empty), a permissive ``{"type": "object"}`` is used.
        description: Tool description for the LLM.
        wants_state: Set to ``True`` to receive a filtered TensorDict
            view via ``ctx.state`` (mirrors the legacy
            ``pass_state_to_tools`` knob).
        render_call: Custom function formatting ``(name, args)`` into
            the string the legacy transform parses.  Defaults to
            :func:`_default_render`.

    Returns:
        A :class:`Tool`-conforming object.

    Examples:
        >>> from torchrl.envs.llm.transforms import PythonInterpreter  # doctest: +SKIP
        >>> from torchrl.envs.llm.agentic.tools import as_tool
        >>> tool = as_tool(
        ...     PythonInterpreter(persistent=True),
        ...     name="python",
        ...     input_schema={"type": "object",
        ...                   "properties": {"code": {"type": "string"}},
        ...                   "required": ["code"]},
        ... )
    """
    # ``or`` on purpose: an empty schema/renderer also falls back.
    schema = input_schema or {"type": "object"}
    renderer = render_call or _default_render
    return _LegacyToolAdapter(
        transform,
        name=name,
        input_schema=schema,
        description=description,
        wants_state=wants_state,
        render_call=renderer,
    )
"required": ["code"]}, + ... ) + """ + return _LegacyToolAdapter( + transform, + name=name, + input_schema=input_schema or {"type": "object"}, + description=description, + wants_state=wants_state, + render_call=render_call or _default_render, + ) + + +class _LegacyToolAdapter: + """Tool that delegates to a legacy ``ToolTransformBase``-style object.""" + + output_schema = None + + def __init__( + self, + transform: Any, + *, + name: str, + input_schema: Mapping[str, Any], + description: str, + wants_state: bool, + render_call: Callable[[str, Mapping[str, Any]], str], + ) -> None: + # Per-instance attrs deliberately (rather than ClassVar) -- the + # adapter is a factory and each invocation produces a distinct + # tool with its own name/schema. + self.name = name + self.input_schema = dict(input_schema) + self.description = description + self.wants_state = wants_state + self._transform = transform + self._render = render_call + + async def setup(self) -> None: + # Legacy transforms don't have a uniform setup hook; honor it if + # the duck-typed object provides one. + hook = getattr(self._transform, "setup", None) + if callable(hook): + res = hook() + if asyncio.iscoroutine(res): + await res + + async def teardown(self) -> None: + hook = getattr(self._transform, "teardown", None) + if callable(hook): + res = hook() + if asyncio.iscoroutine(res): + await res + # The legacy PythonInterpreter has its own close/shutdown patterns; + # if exposed, call them. 
+ for method_name in ("close", "shutdown"): + method = getattr(self._transform, method_name, None) + if callable(method): + try: + res = method() + if asyncio.iscoroutine(res): + await res + except Exception: # pragma: no cover -- defensive + pass + break + + async def run( + self, args: Mapping[str, Any], ctx: ToolContext + ) -> ToolResult: + rendered = self._render(self.name, args) + process = getattr(self._transform, "_process_batch_item", None) + if not callable(process): + raise ToolError( + f"legacy transform {type(self._transform).__name__!r} has " + "no _process_batch_item; not adaptable" + ) + # Legacy tools are sync; offload so we don't block the event loop. + try: + results = await asyncio.to_thread(process, rendered, 0) + except Exception as e: # pragma: no cover -- depends on legacy impl + raise ToolError(str(e)) from e + if not results: + return ToolResult.from_text("", meta={"adapter": "legacy"}) + text = "\n".join(str(r) for r in results) + return ToolResult( + parts=(TextPart(text=text),), + is_error=False, + meta={"adapter": "legacy", "count": len(results)}, + ) + + +__all__ = ["as_tool"]