From 500e59ab7431f9b97393e7534e43b641b9e2c2df Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 10 May 2026 21:17:58 +0100 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- docs/source/reference/llms_envs.rst | 108 +++++ test/llm/test_llm_transforms.py | 446 ++++++++++++++++++ torchrl/envs/llm/agentic/__init__.py | 61 +++ torchrl/envs/llm/agentic/parsers/__init__.py | 32 ++ torchrl/envs/llm/agentic/parsers/anthropic.py | 115 +++++ .../envs/llm/agentic/parsers/json_block.py | 88 ++++ torchrl/envs/llm/agentic/parsers/openai.py | 123 +++++ torchrl/envs/llm/agentic/parsers/xml.py | 92 ++++ torchrl/envs/llm/agentic/protocols.py | 272 +++++++++++ torchrl/envs/llm/agentic/repl/__init__.py | 25 + torchrl/envs/llm/agentic/repl/base.py | 108 +++++ torchrl/envs/llm/agentic/repl/jupyter.py | 210 +++++++++ torchrl/envs/llm/agentic/repl/subprocess.py | 262 ++++++++++ torchrl/envs/llm/agentic/sandbox/__init__.py | 60 +++ torchrl/envs/llm/agentic/sandbox/base.py | 192 ++++++++ torchrl/envs/llm/agentic/sandbox/docker.py | 70 +++ torchrl/envs/llm/agentic/sandbox/e2b.py | 64 +++ torchrl/envs/llm/agentic/sandbox/modal.py | 64 +++ .../llm/agentic/sandbox/subprocess_bwrap.py | 214 +++++++++ .../agentic/sandbox/subprocess_seatbelt.py | 184 ++++++++ torchrl/envs/llm/agentic/sandbox/unsafe.py | 163 +++++++ torchrl/envs/llm/agentic/schema.py | 102 ++++ 22 files changed, 3055 insertions(+) create mode 100644 torchrl/envs/llm/agentic/__init__.py create mode 100644 torchrl/envs/llm/agentic/parsers/__init__.py create mode 100644 torchrl/envs/llm/agentic/parsers/anthropic.py create mode 100644 torchrl/envs/llm/agentic/parsers/json_block.py create mode 100644 torchrl/envs/llm/agentic/parsers/openai.py create mode 100644 torchrl/envs/llm/agentic/parsers/xml.py create mode 100644 torchrl/envs/llm/agentic/protocols.py create mode 100644 torchrl/envs/llm/agentic/repl/__init__.py create mode 100644 torchrl/envs/llm/agentic/repl/base.py create mode 100644 torchrl/envs/llm/agentic/repl/jupyter.py create mode 100644 torchrl/envs/llm/agentic/repl/subprocess.py create mode 100644 torchrl/envs/llm/agentic/sandbox/__init__.py create mode 100644 torchrl/envs/llm/agentic/sandbox/base.py create mode 100644 torchrl/envs/llm/agentic/sandbox/docker.py create mode 100644 torchrl/envs/llm/agentic/sandbox/e2b.py create mode 100644 torchrl/envs/llm/agentic/sandbox/modal.py create mode 100644 torchrl/envs/llm/agentic/sandbox/subprocess_bwrap.py create mode 100644 torchrl/envs/llm/agentic/sandbox/subprocess_seatbelt.py create mode 100644 torchrl/envs/llm/agentic/sandbox/unsafe.py create mode 100644 torchrl/envs/llm/agentic/schema.py diff --git a/docs/source/reference/llms_envs.rst b/docs/source/reference/llms_envs.rst index d457889216f..879ba3e6d4f 100644 --- a/docs/source/reference/llms_envs.rst +++ b/docs/source/reference/llms_envs.rst @@ -28,3 +28,111 @@ The environment layer orchestrates data loading, tool execution, reward computat LLMHashingEnv make_mlgym MLGymWrapper + +Agentic toolkit (preview) +------------------------- + +.. currentmodule:: torchrl.envs.llm.agentic + +The :mod:`torchrl.envs.llm.agentic` package provides a SOTA, async-first +substrate for tool-calling agents. The headline orchestrator +(``ToolCompose``) lands in a follow-up commit; this preview ships the +contracts, parsers, sandboxing, and stateful REPLs that it builds on. 
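+
+Until ``ToolCompose`` lands, the pieces can be wired together by hand. The
+sketch below is illustrative only: ``EchoTool`` and the ``one_turn`` helper
+are throwaway names defined here for the example; everything else is the
+package API described in the following sections::
+
+    import asyncio
+
+    from torchrl.envs.llm.agentic import Tool, ToolContext, ToolResult
+    from torchrl.envs.llm.agentic.parsers import XMLToolCallParser
+
+    class EchoTool:
+        name = "echo"
+        description = "Returns its input."
+        input_schema = {
+            "type": "object",
+            "properties": {"text": {"type": "string"}},
+            "required": ["text"],
+        }
+        output_schema = None
+        wants_state = False
+
+        async def run(self, args, ctx):
+            return ToolResult.from_text(args["text"])
+
+        async def setup(self):
+            pass
+
+        async def teardown(self):
+            pass
+
+    async def one_turn(response: str) -> dict:
+        parser = XMLToolCallParser()
+        tool = EchoTool()
+        # Duck-typed tool; the Tool protocol is runtime-checkable.
+        assert isinstance(tool, Tool)
+        parsed = parser.parse(response)
+        call = parsed.calls[0]
+        result = await tool.run(call.args, ToolContext(call_id=call.call_id))
+        # A role/content mapping ready to append to the chat history.
+        return parser.render_result(call.call_id, result)
+
+    msg = asyncio.run(one_turn('<tool name="echo" tag="t1">{"text": "hi"}</tool>'))
+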
+ +Tool contracts +~~~~~~~~~~~~~~ + +A :class:`Tool` is a pure async object with a name, a JSON Schema +``input_schema``, and an async ``run(args, ctx)`` method returning a +:class:`ToolResult`. Calls flow through a :class:`ToolCallParser` (one of +the four built-ins below) which guarantees a stable ``call_id`` for +every invocation. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + Tool + ToolContext + ToolResult + TextPart + JsonPart + ImagePart + FileRefPart + ParsedCall + ParseResult + ToolCallParser + +Parsers +~~~~~~~ + +.. currentmodule:: torchrl.envs.llm.agentic.parsers + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + XMLToolCallParser + JSONToolCallParser + OpenAIToolCallParser + AnthropicToolUseParser + +Sandboxing +~~~~~~~~~~ + +.. currentmodule:: torchrl.envs.llm.agentic.sandbox + +A :class:`Sandbox` is an async context manager that runs subprocess +commands with bounded resources, controlled filesystem access, and +opt-in network egress. The default backends are +:class:`BubblewrapSandbox` on Linux and :class:`SeatbeltSandbox` on +macOS; pick one explicitly or use :func:`default_sandbox`. + +For environments without those binaries, :class:`UnsafeSubprocessSandbox` +provides a no-isolation fallback that warns loudly on every +``open()``. Do not use it with untrusted model output. + +.. note:: + Apple has officially deprecated ``sandbox-exec``, but it still ships + with macOS 14+ and remains the most portable in-process isolation + primitive on that platform. For stronger or cross-platform + isolation, prefer :class:`DockerSandbox` (real implementation + tracked in the package TODO list). + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + Sandbox + SandboxResult + ResourceLimits + BubblewrapSandbox + SeatbeltSandbox + UnsafeSubprocessSandbox + DockerSandbox + E2BSandbox + ModalSandbox + default_sandbox + +Stateful REPLs +~~~~~~~~~~~~~~ + +.. currentmodule:: torchrl.envs.llm.agentic.repl + +A :class:`Repl` runs stateful code inside a :class:`Sandbox` so an +agent can build up variables across multiple tool calls. The default +:class:`JupyterRepl` uses an IPython kernel for rich outputs (images, +JSON, plots) and clean restarts (optional dependency: +``jupyter_client``). :class:`SubprocessRepl` is a no-dep fallback that +trades rich display for portability. + +.. 
autosummary:: + :toctree: generated/ + :template: rl_template.rst + + Repl + ReplResult + ReplDisplay + ReplError + JupyterRepl + SubprocessRepl diff --git a/test/llm/test_llm_transforms.py b/test/llm/test_llm_transforms.py index b6d0c1a6853..99e52973efd 100644 --- a/test/llm/test_llm_transforms.py +++ b/test/llm/test_llm_transforms.py @@ -718,3 +718,449 @@ def test_empty_history_handling(self, tokenizer): assert ("tokens", "prompt") in result.keys(True, True) tokens = result.get(("tokens", "prompt"), as_list=True) assert tokens[0].numel() > 0 + + +# --------------------------------------------------------------------------- +# Agentic toolkit (torchrl.envs.llm.agentic) +# --------------------------------------------------------------------------- + +import asyncio # noqa: E402 +import socket # noqa: E402 +import sys # noqa: E402 +import warnings # noqa: E402 + +from torchrl.envs.llm.agentic import ( # noqa: E402 + ParsedCall, + TextPart, + Tool, + ToolCallParser, + ToolContext, + ToolResult, + validate_args, +) +from torchrl.envs.llm.agentic.parsers import ( # noqa: E402 + AnthropicToolUseParser, + JSONToolCallParser, + OpenAIToolCallParser, + XMLToolCallParser, +) +from torchrl.envs.llm.agentic.repl import ( # noqa: E402 + SubprocessRepl, + _has_jupyter_client, +) +from torchrl.envs.llm.agentic.sandbox import ( # noqa: E402 + BubblewrapSandbox, + ResourceLimits, + SandboxError, + SeatbeltSandbox, + UnsafeSubprocessSandbox, + default_sandbox, +) +from torchrl.envs.llm.agentic.sandbox.subprocess_bwrap import _has_bwrap # noqa: E402 +from torchrl.envs.llm.agentic.sandbox.subprocess_seatbelt import ( # noqa: E402 + _has_sandbox_exec, +) + + +def _run(coro): + return asyncio.get_event_loop().run_until_complete(coro) if False else asyncio.run(coro) + + +class TestAgenticParsers: + """Per-parser conformance: parse, render_call round-trip, render_result, + stable call_id (parser-supplied or assigned). 
+ """ + + @pytest.mark.parametrize( + "parser_cls", + [XMLToolCallParser, JSONToolCallParser, OpenAIToolCallParser, AnthropicToolUseParser], + ) + def test_implements_protocol(self, parser_cls): + p = parser_cls() + assert isinstance(p, ToolCallParser) + assert isinstance(p.name, str) and p.name + + def test_xml_parse_and_call_id(self): + p = XMLToolCallParser() + r = p.parse('{"text": "hi"}tail') + assert len(r.calls) == 1 + c = r.calls[0] + assert c.tool == "echo" + assert c.args == {"text": "hi"} + assert c.call_id == "t1" # tag becomes call_id when present + assert c.tag == "t1" + assert r.text == "tail" + + def test_xml_assigns_call_id_when_no_tag(self): + p = XMLToolCallParser() + r = p.parse('{}') + assert r.calls[0].call_id # non-empty + assert r.calls[0].tag is None + + def test_xml_round_trip(self): + p = XMLToolCallParser() + call = ParsedCall( + tool="echo", args={"text": "hi"}, call_id="abc", tag="abc" + ) + rendered = p.render_call(call) + re_parsed = p.parse(rendered) + assert re_parsed.calls[0].tool == "echo" + assert re_parsed.calls[0].args == {"text": "hi"} + assert re_parsed.calls[0].call_id == "abc" + + def test_xml_render_result(self): + p = XMLToolCallParser() + msg = p.render_result("c1", ToolResult.from_text("output")) + assert msg["role"] == "tool" + assert "c1" in msg["content"] + assert "output" in msg["content"] + + def test_json_block_parse_with_id(self): + p = JSONToolCallParser() + resp = json.dumps( + { + "message": "ok", + "tools": [{"tool": "echo", "args": {"x": 1}, "id": "j1"}], + } + ) + r = p.parse(resp) + assert r.text == "ok" + assert r.calls[0].tool == "echo" + assert r.calls[0].args == {"x": 1} + assert r.calls[0].call_id == "j1" + + def test_json_block_assigns_call_id(self): + p = JSONToolCallParser() + resp = json.dumps({"message": "", "tools": [{"tool": "x", "args": {}}]}) + r = p.parse(resp) + assert r.calls[0].call_id # uuid hex + + def test_json_block_invalid_json_falls_back_to_text(self): + p = JSONToolCallParser() + r = p.parse("not json at all") + assert r.text == "not json at all" + assert r.calls == () + + def test_openai_preserves_id_and_decodes_args(self): + p = OpenAIToolCallParser() + r = p.parse( + { + "role": "assistant", + "content": "thinking", + "tool_calls": [ + { + "id": "call_a", + "type": "function", + "function": { + "name": "search", + "arguments": '{"q": "torchrl"}', + }, + } + ], + } + ) + assert r.calls[0].tool == "search" + assert r.calls[0].args == {"q": "torchrl"} + assert r.calls[0].call_id == "call_a" + + def test_openai_render_result_uses_tool_call_id(self): + p = OpenAIToolCallParser() + msg = p.render_result("call_a", ToolResult.from_text("done")) + assert msg["role"] == "tool" + assert msg["tool_call_id"] == "call_a" + assert msg["content"] == "done" + + def test_anthropic_extracts_text_and_tool_use(self): + p = AnthropicToolUseParser() + r = p.parse( + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me search."}, + { + "type": "tool_use", + "id": "toolu_a", + "name": "search", + "input": {"q": "x"}, + }, + ], + } + ) + assert r.text == "Let me search." 
+ assert r.calls[0].tool == "search" + assert r.calls[0].args == {"q": "x"} + assert r.calls[0].call_id == "toolu_a" + + def test_anthropic_render_result_uses_tool_use_id(self): + p = AnthropicToolUseParser() + msg = p.render_result( + "toolu_a", ToolResult.from_text("hit", is_error=False) + ) + assert msg["role"] == "user" + assert msg["content"][0]["type"] == "tool_result" + assert msg["content"][0]["tool_use_id"] == "toolu_a" + + def test_validate_args_required(self): + schema = { + "type": "object", + "properties": {"code": {"type": "string"}}, + "required": ["code"], + } + validate_args({"code": "print(1)"}, schema) + with pytest.raises(Exception): + validate_args({}, schema) + + def test_validate_args_type_mismatch(self): + schema = { + "type": "object", + "properties": {"n": {"type": "integer"}}, + } + validate_args({"n": 3}, schema) + with pytest.raises(Exception): + validate_args({"n": "three"}, schema) + + def test_tool_protocol_runtime_check(self): + class _T: + name = "t" + description = "d" + input_schema = {"type": "object", "properties": {}} + output_schema = None + wants_state = False + + async def run(self, args, ctx): + return ToolResult.from_text("ok") + + async def setup(self): + pass + + async def teardown(self): + pass + + assert isinstance(_T(), Tool) + + +class TestAgenticSandbox: + """Sandbox protocol conformance + sandbox-escape negatives.""" + + def test_unsafe_warns_on_open(self): + async def go(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + async with UnsafeSubprocessSandbox() as _s: + pass + assert any( + issubclass(w.category, UserWarning) for w in caught + ) + + _run(go()) + + def test_unsafe_runs_simple_command(self): + async def go(): + async with UnsafeSubprocessSandbox( + ResourceLimits(wall_seconds=5) + ) as s: + r = await s.run(["/bin/echo", "hello"]) + assert r.exit_code == 0 + assert r.stdout.strip() == "hello" + assert not r.timed_out + + _run(go()) + + def test_unsafe_timeout(self): + async def go(): + async with UnsafeSubprocessSandbox( + ResourceLimits(wall_seconds=0.2) + ) as s: + r = await s.run(["/bin/sleep", "5"]) + assert r.timed_out + + _run(go()) + + def test_resource_limits_narrow(self): + a = ResourceLimits(wall_seconds=10, network="full") + b = ResourceLimits(wall_seconds=2, network="none") + c = a.narrow(b) + assert c.wall_seconds == 2 + assert c.network == "none" + # Reverse direction: narrow keeps the strictest. + c2 = b.narrow(a) + assert c2.wall_seconds == 2 + assert c2.network == "none" + + def test_default_sandbox_picks_platform(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + s = default_sandbox() + if sys.platform.startswith("linux") and _has_bwrap: + assert isinstance(s, BubblewrapSandbox) + elif sys.platform == "darwin" and _has_sandbox_exec: + assert isinstance(s, SeatbeltSandbox) + else: + assert isinstance(s, UnsafeSubprocessSandbox) + + @pytest.mark.skipif( + not (sys.platform.startswith("linux") and _has_bwrap), + reason="bubblewrap not available", + ) + def test_bubblewrap_blocks_fs_escape(self, tmp_path): + """Writes outside fs_write_roots must fail.""" + write_root = tmp_path / "work" + write_root.mkdir() + outside = tmp_path / "forbidden" + + async def go(): + limits = ResourceLimits( + wall_seconds=5, + network="none", + fs_write_roots=(str(write_root),), + ) + async with BubblewrapSandbox(limits=limits) as s: + # Inside the write root: must succeed. 
+ inside_path = write_root / "inside.txt" + r = await s.run( + [ + "/bin/sh", + "-c", + f"echo hi > {inside_path}", + ] + ) + assert r.exit_code == 0 + assert inside_path.read_text().strip() == "hi" + # Outside the write root: must fail. + r2 = await s.run( + [ + "/bin/sh", + "-c", + f"echo nope > {outside}", + ] + ) + assert r2.exit_code != 0 + assert not outside.exists() + + _run(go()) + + @pytest.mark.skipif( + not (sys.platform.startswith("linux") and _has_bwrap), + reason="bubblewrap not available", + ) + def test_bubblewrap_blocks_network(self): + """network='none' must block outbound TCP.""" + async def go(): + limits = ResourceLimits(wall_seconds=5, network="none") + async with BubblewrapSandbox(limits=limits) as s: + r = await s.run( + [ + "python3", + "-c", + "import socket; " + "socket.create_connection(('1.1.1.1', 80), timeout=2)", + ] + ) + assert r.exit_code != 0 + + _run(go()) + + @pytest.mark.skipif( + not (sys.platform == "darwin" and _has_sandbox_exec), + reason="sandbox-exec not available", + ) + def test_seatbelt_blocks_fs_escape(self, tmp_path): + write_root = tmp_path / "work" + write_root.mkdir() + outside = tmp_path / "forbidden" + + async def go(): + limits = ResourceLimits( + wall_seconds=5, + network="none", + fs_write_roots=(str(write_root),), + ) + async with SeatbeltSandbox(limits=limits) as s: + inside_path = write_root / "inside.txt" + r = await s.run( + ["/bin/sh", "-c", f"echo hi > {inside_path}"] + ) + assert r.exit_code == 0 + r2 = await s.run( + ["/bin/sh", "-c", f"echo nope > {outside}"] + ) + assert r2.exit_code != 0 + assert not outside.exists() + + _run(go()) + + +class TestAgenticRepl: + """REPL state, error capture, restart, timeout.""" + + def test_subprocess_repl_state_persists(self): + async def go(): + async with UnsafeSubprocessSandbox( + ResourceLimits(wall_seconds=10) + ) as s: + async with SubprocessRepl(s) as r: + r1 = await r.execute("x = 41") + assert r1.error is None + r2 = await r.execute("print(x + 1)") + assert r2.error is None + assert r2.stdout.strip() == "42" + + _run(go()) + + def test_subprocess_repl_captures_errors(self): + async def go(): + async with UnsafeSubprocessSandbox( + ResourceLimits(wall_seconds=10) + ) as s: + async with SubprocessRepl(s) as r: + res = await r.execute("1/0") + assert res.error is not None + assert res.error.ename == "ZeroDivisionError" + + _run(go()) + + def test_subprocess_repl_restart_clears_state(self): + async def go(): + async with UnsafeSubprocessSandbox( + ResourceLimits(wall_seconds=10) + ) as s: + async with SubprocessRepl(s) as r: + await r.execute("y = 99") + await r.restart() + res = await r.execute("print(y)") + assert res.error is not None # NameError + + _run(go()) + + def test_subprocess_repl_timeout(self): + async def go(): + async with UnsafeSubprocessSandbox( + ResourceLimits(wall_seconds=10) + ) as s: + async with SubprocessRepl(s) as r: + res = await r.execute( + "import time; time.sleep(5)", timeout=0.3 + ) + assert res.timed_out + + _run(go()) + + @pytest.mark.skipif( + not _has_jupyter_client, reason="jupyter_client not installed" + ) + @pytest.mark.slow + def test_jupyter_repl_state_persists(self): + from torchrl.envs.llm.agentic.repl import JupyterRepl + + async def go(): + async with UnsafeSubprocessSandbox( + ResourceLimits(wall_seconds=60) + ) as s: + async with JupyterRepl(s) as r: + r1 = await r.execute("x = 41", timeout=30) + assert r1.error is None, r1 + r2 = await r.execute("print(x + 1)", timeout=30) + assert r2.error is None + assert r2.stdout.strip() == "42" + + 
_run(go()) diff --git a/torchrl/envs/llm/agentic/__init__.py b/torchrl/envs/llm/agentic/__init__.py new file mode 100644 index 00000000000..84406ff8233 --- /dev/null +++ b/torchrl/envs/llm/agentic/__init__.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Agentic toolkit for ChatEnv. + +A first-class, async-first stack for LLM tool use. Drop a +:class:`~torchrl.envs.llm.agentic.ToolCompose` into a ``TransformedEnv`` +wrapping an unmodified :class:`~torchrl.envs.llm.ChatEnv`, register a few +:class:`~torchrl.envs.llm.agentic.Tool` instances, pick a parser, and you +have an agent loop with parallel dispatch, sandboxed execution, and a +stateful REPL. + +See ``docs/source/reference/llms_envs.rst`` and +``docs/source/tutorials/llm_agentic.rst`` for a walkthrough. +""" +# TODO: contributors please update as items are picked up. +# - streaming tool results (AsyncIterator[ToolEvent] from Tool.run) +# - per-tool token-budget accounting +# - E2B / Modal real implementations (stubs land first) +# - harmony parser (gpt-oss / o1-style) +# - Ray dispatcher (ToolCompose(parallel="ray")) +# - multimodal tool outputs (image / audio in ToolResult.parts) +# - structured-output validation against Tool.output_schema +# - per-tool retry / circuit breaker +# - tool-result caching (content-addressed) for replay +# - formal deprecation of legacy tool transforms once the new API soaks +from __future__ import annotations + +from .protocols import ( + FileRefPart, + ImagePart, + JsonPart, + ParsedCall, + ParseResult, + Tool, + ToolCallParser, + ToolContext, + ToolError, + ToolResult, + ToolResultPart, + TextPart, +) +from .schema import json_schema_from_pydantic, validate_args + +__all__ = [ + "FileRefPart", + "ImagePart", + "JsonPart", + "ParseResult", + "ParsedCall", + "TextPart", + "Tool", + "ToolCallParser", + "ToolContext", + "ToolError", + "ToolResult", + "ToolResultPart", + "json_schema_from_pydantic", + "validate_args", +] diff --git a/torchrl/envs/llm/agentic/parsers/__init__.py b/torchrl/envs/llm/agentic/parsers/__init__.py new file mode 100644 index 00000000000..48411dbe444 --- /dev/null +++ b/torchrl/envs/llm/agentic/parsers/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Pluggable parsers turning LLM responses into :class:`ParsedCall` items. + +Available parsers: + +- :class:`XMLToolCallParser` -- ``{...}`` blocks + embedded in the message body. Successor to the legacy + :class:`~torchrl.envs.llm.transforms.XMLBlockParser`. +- :class:`JSONToolCallParser` -- top-level JSON with ``message`` and + ``tools`` fields. Successor to + :class:`~torchrl.envs.llm.transforms.JSONCallParser`. +- :class:`OpenAIToolCallParser` -- structured ``tool_calls`` array on the + assistant message (OpenAI / vLLM-with-tools). +- :class:`AnthropicToolUseParser` -- ``tool_use`` content blocks + (Anthropic). 
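+
+All four expose the same ``parse`` / ``render_call`` / ``render_result``
+surface; a minimal sketch with the XML parser (the others behave the same,
+modulo message shape):
+
+    >>> from torchrl.envs.llm.agentic import ToolResult
+    >>> from torchrl.envs.llm.agentic.parsers import XMLToolCallParser
+    >>> p = XMLToolCallParser()
+    >>> r = p.parse('<tool name="echo" tag="t1">{"text": "hi"}</tool>done')
+    >>> r.text, r.calls[0].tool, r.calls[0].call_id
+    ('done', 'echo', 't1')
+    >>> p.render_result("t1", ToolResult.from_text("hi"))["role"]
+    'tool'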
+""" +from __future__ import annotations + +from .anthropic import AnthropicToolUseParser +from .json_block import JSONToolCallParser +from .openai import OpenAIToolCallParser +from .xml import XMLToolCallParser + +__all__ = [ + "AnthropicToolUseParser", + "JSONToolCallParser", + "OpenAIToolCallParser", + "XMLToolCallParser", +] diff --git a/torchrl/envs/llm/agentic/parsers/anthropic.py b/torchrl/envs/llm/agentic/parsers/anthropic.py new file mode 100644 index 00000000000..d2036e36071 --- /dev/null +++ b/torchrl/envs/llm/agentic/parsers/anthropic.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Anthropic ``tool_use`` parser. + +Reads ``tool_use`` content blocks from an assistant message and emits a +``tool_result`` block per call when rendering. Matches the Messages API +shape used by Claude. +""" +from __future__ import annotations + +import json +import uuid +from collections.abc import Mapping +from typing import Any, ClassVar + +from ..protocols import ParsedCall, ParseResult, ToolResult + + +class AnthropicToolUseParser: + """Parses Anthropic-style ``tool_use`` content blocks. + + Accepts either the full assistant message:: + + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me search."}, + {"type": "tool_use", "id": "toolu_1", + "name": "search", "input": {"q": "x"}} + ] + } + + or a bare ``content`` list. Each block's ``id`` is preserved as + :attr:`ParsedCall.call_id`. + + Examples: + >>> p = AnthropicToolUseParser() + >>> resp = {"role": "assistant", "content": [ + ... {"type": "text", "text": "ok"}, + ... {"type": "tool_use", "id": "u1", + ... "name": "echo", "input": {"text": "hi"}}, + ... 
]} + >>> r = p.parse(resp) + >>> r.text, r.calls[0].tool, r.calls[0].args, r.calls[0].call_id + ('ok', 'echo', {'text': 'hi'}, 'u1') + """ + + name: ClassVar[str] = "anthropic" + + def parse(self, response: str | Mapping[str, Any]) -> ParseResult: + if isinstance(response, str): + try: + data: Any = json.loads(response) + except json.JSONDecodeError: + return ParseResult(text=response, calls=(), raw=response) + else: + data = response + if isinstance(data, Mapping): + content = data.get("content") + else: + content = data + if isinstance(content, str): + return ParseResult(text=content, calls=(), raw=response) + if not isinstance(content, list): + return ParseResult(text="", calls=(), raw=response) + text_parts: list[str] = [] + calls: list[ParsedCall] = [] + for block in content: + if not isinstance(block, Mapping): + continue + btype = block.get("type") + if btype == "text": + text_parts.append(str(block.get("text", ""))) + elif btype == "tool_use": + calls.append( + ParsedCall( + tool=str(block.get("name", "")), + args=dict(block.get("input") or {}), + call_id=str(block.get("id") or uuid.uuid4().hex), + tag=None, + ) + ) + return ParseResult( + text="\n".join(text_parts).strip(), + calls=tuple(calls), + raw=response, + ) + + def render_call(self, call: ParsedCall) -> str: + return json.dumps( + { + "type": "tool_use", + "id": call.call_id, + "name": call.tool, + "input": dict(call.args), + }, + ensure_ascii=False, + ) + + def render_result( + self, call_id: str, result: ToolResult + ) -> Mapping[str, Any]: + return { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": call_id, + "content": result.text, + "is_error": result.is_error, + } + ], + } diff --git a/torchrl/envs/llm/agentic/parsers/json_block.py b/torchrl/envs/llm/agentic/parsers/json_block.py new file mode 100644 index 00000000000..c290b72dbfb --- /dev/null +++ b/torchrl/envs/llm/agentic/parsers/json_block.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""JSON-block parser: top-level ``{"message": ..., "tools": [...]}``.""" +from __future__ import annotations + +import json +import uuid +from collections.abc import Mapping +from typing import Any, ClassVar + +from ..protocols import ParsedCall, ParseResult, ToolResult + + +class JSONToolCallParser: + """Parses LLM responses formatted as a single JSON object. + + Expected shape:: + + { + "message": "Let me search.", + "tools": [ + {"tool": "search", "args": {"query": "x"}, "id": "c1"}, + {"tool": "summarize", "args": {"text": "..."}} + ] + } + + The optional ``id`` field on each call is used as the stable + ``call_id``; when absent a uuid4 hex is assigned. Successor to + :class:`~torchrl.envs.llm.transforms.JSONCallParser`. 
+ + Examples: + >>> p = JSONToolCallParser() + >>> resp = '{"message": "ok", "tools": [{"tool": "echo", "args": {"x": 1}}]}' + >>> r = p.parse(resp) + >>> r.text, r.calls[0].tool, r.calls[0].args + ('ok', 'echo', {'x': 1}) + """ + + name: ClassVar[str] = "json_block" + + def parse(self, response: str | Mapping[str, Any]) -> ParseResult: + if isinstance(response, str): + try: + data = json.loads(response) + except json.JSONDecodeError: + return ParseResult(text=response, calls=(), raw=response) + else: + data = response + if not isinstance(data, Mapping): + return ParseResult(text=str(data), calls=(), raw=response) + tools_data = data.get("tools") or () + calls = tuple( + ParsedCall( + tool=str(c["tool"]), + args=dict(c.get("args") or {}), + call_id=str(c.get("id") or c.get("call_id") or uuid.uuid4().hex), + tag=c.get("tag"), + ) + for c in tools_data + ) + return ParseResult( + text=str(data.get("message", "")), + calls=calls, + raw=response, + ) + + def render_call(self, call: ParsedCall) -> str: + return json.dumps( + {"tool": call.tool, "args": dict(call.args), "id": call.call_id}, + ensure_ascii=False, + ) + + def render_result( + self, call_id: str, result: ToolResult + ) -> Mapping[str, Any]: + return { + "role": "tool", + "content": json.dumps( + { + "id": call_id, + "is_error": result.is_error, + "output": result.text, + }, + ensure_ascii=False, + ), + } diff --git a/torchrl/envs/llm/agentic/parsers/openai.py b/torchrl/envs/llm/agentic/parsers/openai.py new file mode 100644 index 00000000000..a785903748c --- /dev/null +++ b/torchrl/envs/llm/agentic/parsers/openai.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""OpenAI-compatible tool-call parser. + +Reads structured tool calls from the assistant message envelope +(``message.tool_calls`` or top-level ``tool_calls``), as produced by the +OpenAI Chat Completions API and any compatible server (vLLM with +``--enable-auto-tool-choice``, etc.). +""" +from __future__ import annotations + +import json +import uuid +from collections.abc import Mapping +from typing import Any, ClassVar + +from ..protocols import ParsedCall, ParseResult, ToolResult + + +class OpenAIToolCallParser: + """Parses OpenAI-style ``tool_calls`` from an assistant message. + + Accepts any of these shapes: + + - The full message dict:: + + {"role": "assistant", "content": "...", "tool_calls": [...]} + + - The choice dict:: + + {"message": {... "tool_calls": [...]}} + + - A bare list under ``tool_calls`` at the top level. + + Each call's ``id`` is preserved as :attr:`ParsedCall.call_id`. Arguments + are JSON-decoded from the ``function.arguments`` string. + + Examples: + >>> p = OpenAIToolCallParser() + >>> resp = { + ... "role": "assistant", + ... "content": "thinking...", + ... "tool_calls": [{ + ... "id": "call_a", + ... "type": "function", + ... "function": {"name": "echo", + ... "arguments": '{"text": "hi"}'}, + ... }], + ... 
} + >>> r = p.parse(resp) + >>> r.calls[0].tool, r.calls[0].args, r.calls[0].call_id + ('echo', {'text': 'hi'}, 'call_a') + """ + + name: ClassVar[str] = "openai" + + def parse(self, response: str | Mapping[str, Any]) -> ParseResult: + if isinstance(response, str): + try: + data: Any = json.loads(response) + except json.JSONDecodeError: + return ParseResult(text=response, calls=(), raw=response) + else: + data = response + if isinstance(data, Mapping): + message = data.get("message", data) + content = message.get("content") or "" + tool_calls = message.get("tool_calls") or data.get("tool_calls") or () + else: + content = "" + tool_calls = data or () + calls: list[ParsedCall] = [] + for tc in tool_calls: + if not isinstance(tc, Mapping): + continue + fn = tc.get("function") or {} + raw_args = fn.get("arguments") + if isinstance(raw_args, str): + try: + args = json.loads(raw_args) if raw_args else {} + except json.JSONDecodeError: + args = {"raw": raw_args} + else: + args = dict(raw_args or {}) + calls.append( + ParsedCall( + tool=str(fn.get("name", "")), + args=args, + call_id=str(tc.get("id") or uuid.uuid4().hex), + tag=None, + ) + ) + return ParseResult( + text=str(content) if isinstance(content, str) else "", + calls=tuple(calls), + raw=response, + ) + + def render_call(self, call: ParsedCall) -> str: + return json.dumps( + { + "id": call.call_id, + "type": "function", + "function": { + "name": call.tool, + "arguments": json.dumps(dict(call.args), ensure_ascii=False), + }, + }, + ensure_ascii=False, + ) + + def render_result( + self, call_id: str, result: ToolResult + ) -> Mapping[str, Any]: + # OpenAI shape: a "tool" role message with tool_call_id correlation. + return { + "role": "tool", + "tool_call_id": call_id, + "content": result.text, + "is_error": result.is_error, + } diff --git a/torchrl/envs/llm/agentic/parsers/xml.py b/torchrl/envs/llm/agentic/parsers/xml.py new file mode 100644 index 00000000000..e305b0f006b --- /dev/null +++ b/torchrl/envs/llm/agentic/parsers/xml.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""XML-block parser: ``{json}``.""" +from __future__ import annotations + +import json +import re +import uuid +from collections.abc import Mapping +from typing import Any, ClassVar + +from ..protocols import ParsedCall, ParseResult, ToolResult + + +class XMLToolCallParser: + r"""Parses XML-style tool blocks embedded in the assistant message. + + Format: + + {"query": "torchrl"} + + or, for argless tools: + + + + Successor to :class:`~torchrl.envs.llm.transforms.XMLBlockParser`. Differs + in that every :class:`ParsedCall` is given a stable ``call_id`` (the + ``tag`` if present, else a uuid4) so results can be correlated across + the dispatch boundary. 
+ + Examples: + >>> p = XMLToolCallParser() + >>> r = p.parse('{"text": "hi"}ok') + >>> r.calls[0].tool, r.calls[0].args, r.calls[0].call_id, r.text + ('echo', {'text': 'hi'}, '1', 'ok') + """ + + name: ClassVar[str] = "xml" + + _re = re.compile( + r'[^"]+)"' + r'(?:\s+tag="(?P[^"]+)")?\s*>\s*' + r"(?P.*?)\s*", + re.DOTALL, + ) + + def parse(self, response: str | Mapping[str, Any]) -> ParseResult: + text = ( + response + if isinstance(response, str) + else str(response.get("text", "")) + ) + calls: list[ParsedCall] = [] + + def repl(m: re.Match) -> str: + tag = m.group("tag") + body = m.group("body") + try: + args = json.loads(body) if body.strip() else {} + except json.JSONDecodeError: + args = {"raw": body} + calls.append( + ParsedCall( + tool=m.group("name"), + args=args, + call_id=tag if tag else uuid.uuid4().hex, + tag=tag, + ) + ) + return "" + + cleaned = self._re.sub(repl, text).strip() + return ParseResult(text=cleaned, calls=tuple(calls), raw=response) + + def render_call(self, call: ParsedCall) -> str: + tag = f' tag="{call.tag}"' if call.tag else "" + body = json.dumps(dict(call.args), ensure_ascii=False) + return f'{body}' + + def render_result( + self, call_id: str, result: ToolResult + ) -> Mapping[str, Any]: + body = result.text + prefix = "[error] " if result.is_error else "" + return { + "role": "tool", + "content": ( + f'{prefix}{body}' + ), + } diff --git a/torchrl/envs/llm/agentic/protocols.py b/torchrl/envs/llm/agentic/protocols.py new file mode 100644 index 00000000000..308e8d4a4a1 --- /dev/null +++ b/torchrl/envs/llm/agentic/protocols.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Core protocols and value types for the agentic toolkit. + +Three layered concerns: + +- :class:`Tool` -- a unit an LLM can invoke by name. Async-first. +- :class:`ToolCallParser` -- turns an LLM response into structured + :class:`ParsedCall` items and renders results back into the family's + message shape. +- (See :class:`~torchrl.envs.llm.agentic.sandbox.Sandbox` and + :class:`~torchrl.envs.llm.agentic.repl.Repl` for isolation and state.) + +Stable ``call_id`` invariant: every :class:`ParsedCall` carries a +``call_id`` (parser-supplied if the family provides one -- OpenAI ``id``, +Anthropic ``tool_use_id`` -- else a parser-assigned uuid4). Round-trips +through :meth:`ToolCallParser.render_result` so downstream consumers can +correlate calls and results. 
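+
+A compact illustration of the invariant (the OpenAI parser here; any
+family behaves the same way):
+
+    >>> from torchrl.envs.llm.agentic import ToolResult
+    >>> from torchrl.envs.llm.agentic.parsers import OpenAIToolCallParser
+    >>> p = OpenAIToolCallParser()
+    >>> msg = {"role": "assistant", "content": "", "tool_calls": [
+    ...     {"id": "call_a", "type": "function",
+    ...      "function": {"name": "search", "arguments": '{"q": "x"}'}}]}
+    >>> call = p.parse(msg).calls[0]
+    >>> p.render_result(call.call_id, ToolResult.from_text("hit"))["tool_call_id"]
+    'call_a'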
+""" +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any, ClassVar, Literal, Protocol, runtime_checkable + +from tensordict import TensorDictBase + + +# ----- result parts ----- + +@dataclass(frozen=True, slots=True) +class TextPart: + """A text fragment of a :class:`ToolResult`.""" + + text: str + kind: Literal["text"] = "text" + + +@dataclass(frozen=True, slots=True) +class JsonPart: + """A JSON-serialisable structured fragment of a :class:`ToolResult`.""" + + data: Any + kind: Literal["json"] = "json" + + +@dataclass(frozen=True, slots=True) +class ImagePart: + """An image fragment of a :class:`ToolResult` (raw bytes + media type).""" + + data: bytes + media_type: str = "image/png" + kind: Literal["image"] = "image" + + +@dataclass(frozen=True, slots=True) +class FileRefPart: + """A reference to a file produced by a tool (path inside the sandbox).""" + + path: str + media_type: str | None = None + kind: Literal["file_ref"] = "file_ref" + + +ToolResultPart = TextPart | JsonPart | ImagePart | FileRefPart + + +@dataclass(frozen=True, slots=True) +class ToolResult: + """The output of a single :meth:`Tool.run` invocation. + + Attributes: + parts: Ordered tuple of result fragments. ``parts[0]`` is conventionally + text. Most call sites only need ``result.text``. + is_error: Whether the tool raised or otherwise produced an error. + ``parts[0]`` should describe the error when ``True``. + meta: Free-form metadata (timing, tokens used, raw provider payload). + """ + + parts: tuple[ToolResultPart, ...] = () + is_error: bool = False + meta: Mapping[str, Any] = field(default_factory=dict) + + @property + def text(self) -> str: + """Concatenation of all :class:`TextPart` and stringified + :class:`JsonPart` content. Convenience for the common case.""" + out: list[str] = [] + for p in self.parts: + if isinstance(p, TextPart): + out.append(p.text) + elif isinstance(p, JsonPart): + import json as _json + + out.append(_json.dumps(p.data, ensure_ascii=False)) + elif isinstance(p, FileRefPart): + out.append(f"") + elif isinstance(p, ImagePart): + out.append(f"") + return "\n".join(out) + + @classmethod + def from_text( + cls, + text: str, + *, + is_error: bool = False, + meta: Mapping[str, Any] | None = None, + ) -> ToolResult: + """Shorthand for the common single-text-part result.""" + return cls( + parts=(TextPart(text=text),), + is_error=is_error, + meta=dict(meta or {}), + ) + + +@dataclass +class ToolError(Exception): + """Raised by tools to signal a structured failure. + + Catching this in :class:`ToolCompose` produces a + :class:`ToolResult` with ``is_error=True``. Anything else surfaces as + an unstructured error (still wrapped, but flagged in ``meta``). + """ + + message: str + detail: Mapping[str, Any] = field(default_factory=dict) + + def __str__(self) -> str: + return self.message + + +# ----- call / parse types ----- + +@dataclass(frozen=True, slots=True) +class ParsedCall: + """A single tool invocation parsed out of an LLM response. + + Attributes: + tool: The name of the tool to invoke. + args: Already-decoded keyword arguments. Validation against + :attr:`Tool.input_schema` happens in :class:`ToolCompose`. + call_id: Stable identifier (parser-assigned if not present in the + source). Round-trips through :meth:`ToolCallParser.render_result`. + tag: Optional human-visible label (back-compat with + ``ExecuteToolsInOrder``). 
+ """ + + tool: str + args: Mapping[str, Any] + call_id: str + tag: str | None = None + + +@dataclass(frozen=True, slots=True) +class ParseResult: + """Output of :meth:`ToolCallParser.parse`. + + Attributes: + text: Cleaned message body with tool-call syntax stripped (when the + family embeds calls in the text -- XML, JSON-block). Empty for + providers where calls live in a structured field (OpenAI, + Anthropic). + calls: Calls in the order the model emitted them. + raw: The original response, for round-trip and debugging. + """ + + text: str + calls: tuple[ParsedCall, ...] + raw: Any = None + + +# ----- context passed to a Tool ----- + +@dataclass +class ToolContext: + """Per-call context handed to :meth:`Tool.run`. + + Attributes: + call_id: The :attr:`ParsedCall.call_id`. Stable across this turn. + tag: Optional :attr:`ParsedCall.tag`. + state: Read-only filtered view of the env state. Only populated when + the owning :class:`ToolCompose` has ``pass_state_to_tools=True`` + *and* the tool has ``wants_state=True``. + sandbox: The compose-level sandbox, if any. Tools may also hold + their own sandbox by reference. + repl: The compose-level REPL, if any. + compose: Back-reference to the owning :class:`ToolCompose` for + tool-to-tool dispatch from inside a tool body. + """ + + call_id: str + tag: str | None = None + state: TensorDictBase | None = None + sandbox: Any | None = None + repl: Any | None = None + compose: Any | None = None + + +# ----- protocols ----- + +@runtime_checkable +class Tool(Protocol): + """A unit invoked by name from an LLM response. + + Subclasses (or duck-typed equivalents) declare ``name``, ``description``, + and ``input_schema`` (JSON Schema dict) at the class level, and implement + an async :meth:`run`. + + A tool may opt in to receiving env state via the ``wants_state`` class + attribute -- :class:`ToolCompose` will populate ``ctx.state`` when both + sides agree. + + Example: + >>> from torchrl.envs.llm.agentic import Tool, ToolContext, ToolResult + >>> class EchoTool: + ... name = "echo" + ... description = "Returns its input." + ... input_schema = {"type": "object", + ... "properties": {"text": {"type": "string"}}, + ... "required": ["text"]} + ... output_schema = None + ... wants_state = False + ... async def run(self, args, ctx): + ... return ToolResult.from_text(args["text"]) + ... async def setup(self): pass + ... async def teardown(self): pass + """ + + name: ClassVar[str] + description: ClassVar[str] + input_schema: ClassVar[Mapping[str, Any]] + output_schema: ClassVar[Mapping[str, Any] | None] + wants_state: ClassVar[bool] + + async def run( + self, args: Mapping[str, Any], ctx: ToolContext + ) -> ToolResult: ... + + async def setup(self) -> None: ... + + async def teardown(self) -> None: ... + + +@runtime_checkable +class ToolCallParser(Protocol): + """Parses an LLM response into :class:`ParsedCall` items and renders + results back into the family's message shape. + + Implementations must guarantee: + + 1. :meth:`parse` is pure and synchronous. + 2. Every returned :class:`ParsedCall` has a non-empty :attr:`call_id`. + 3. ``parse -> render_call`` round-trips for calls produced by + :meth:`parse` (within the same parser family). + 4. :meth:`render_result` produces a mapping suitable for one new + message in :class:`~torchrl.data.llm.History` (keys at minimum: + ``role``, ``content``). + """ + + name: ClassVar[str] + + def parse(self, response: str | Mapping[str, Any]) -> ParseResult: ... + + def render_call(self, call: ParsedCall) -> str: ... 
+ + def render_result( + self, call_id: str, result: ToolResult + ) -> Mapping[str, Any]: ... diff --git a/torchrl/envs/llm/agentic/repl/__init__.py b/torchrl/envs/llm/agentic/repl/__init__.py new file mode 100644 index 00000000000..e196695940d --- /dev/null +++ b/torchrl/envs/llm/agentic/repl/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""REPL backends for the agentic toolkit. + +- :class:`JupyterRepl` -- IPython-kernel-backed; rich outputs, clean + restarts. Optional dependency on ``jupyter_client``. +- :class:`SubprocessRepl` -- persistent ``python3`` subprocess; no extra + dependency, no rich display. +""" +from __future__ import annotations + +from .base import Repl, ReplDisplay, ReplError, ReplResult +from .jupyter import JupyterRepl, _has_jupyter_client +from .subprocess import SubprocessRepl + +__all__ = [ + "JupyterRepl", + "Repl", + "ReplDisplay", + "ReplError", + "ReplResult", + "SubprocessRepl", +] diff --git a/torchrl/envs/llm/agentic/repl/base.py b/torchrl/envs/llm/agentic/repl/base.py new file mode 100644 index 00000000000..7071565013a --- /dev/null +++ b/torchrl/envs/llm/agentic/repl/base.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""REPL protocol and value types. + +A :class:`Repl` runs stateful code inside a :class:`Sandbox`. State persists +across :meth:`execute` calls until :meth:`restart`. :meth:`interrupt` +preserves state but cancels the current execution. Timeouts surface as +``ReplResult.timed_out=True`` rather than raising. +""" +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any, ClassVar, Literal, Protocol, runtime_checkable + +from ..sandbox.base import Sandbox + + +@dataclass(frozen=True, slots=True) +class ReplDisplay: + """A rich output (image, JSON, HTML) emitted via Jupyter's display + protocol. Subprocess REPLs emit nothing here. + """ + + media_type: str + data: Any + + +@dataclass(frozen=True, slots=True) +class ReplError: + """Structured error from the kernel (exception name, value, traceback).""" + + ename: str + evalue: str + traceback: str = "" + + +@dataclass(frozen=True, slots=True) +class ReplResult: + """Outcome of one :meth:`Repl.execute` invocation. + + Attributes: + stdout: Captured stdout. + stderr: Captured stderr. + display: Rich outputs in emit order. + error: Structured error, if any. + timed_out: ``True`` if execution hit the timeout. + execution_count: Monotonic counter (Jupyter); ``-1`` for subprocess. + """ + + stdout: str = "" + stderr: str = "" + display: tuple[ReplDisplay, ...] = () + error: ReplError | None = None + timed_out: bool = False + execution_count: int = -1 + + @property + def text(self) -> str: + """Convenience: stdout + stderr + (error.evalue if error).""" + out: list[str] = [] + if self.stdout: + out.append(self.stdout) + if self.stderr: + out.append(self.stderr) + if self.error: + out.append(f"{self.error.ename}: {self.error.evalue}") + return "\n".join(out).strip() + + +@runtime_checkable +class Repl(Protocol): + """Stateful code-execution session. + + Lifecycle: ``open()`` is idempotent and required before ``execute()``; + ``close()`` releases the kernel. Use as ``async with repl:`` to bracket. 
+ + Invariants: + + - :meth:`execute` is stateful (variables persist) until :meth:`restart`. + - :meth:`interrupt` does not lose state. + - :meth:`execute` never raises on user-code errors; errors surface in + :attr:`ReplResult.error`. Infrastructure failures raise. + """ + + name: ClassVar[str] + sandbox: Sandbox + + async def open(self) -> None: ... + + async def close(self) -> None: ... + + async def __aenter__(self) -> Repl: ... + + async def __aexit__(self, exc_type, exc, tb) -> None: ... + + async def execute( + self, code: str, *, timeout: float | None = None + ) -> ReplResult: ... + + async def interrupt(self) -> None: ... + + async def restart(self) -> None: ... + + +__all__ = ["Repl", "ReplDisplay", "ReplError", "ReplResult"] diff --git a/torchrl/envs/llm/agentic/repl/jupyter.py b/torchrl/envs/llm/agentic/repl/jupyter.py new file mode 100644 index 00000000000..17e955ae318 --- /dev/null +++ b/torchrl/envs/llm/agentic/repl/jupyter.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Jupyter-kernel REPL. + +Spawns an IPython kernel via :mod:`jupyter_client` and drives it through +the standard ZeroMQ channels. Supports rich display outputs, clean +restarts, and proper interrupts. + +Optional dependency: install ``jupyter_client`` and ``ipykernel`` to use. +The import is gated by ``_has_jupyter_client`` and never imported at module +top level for the protocol path -- only inside :class:`JupyterRepl` methods. +""" +from __future__ import annotations + +import asyncio +import importlib.util +import queue +from typing import Any, ClassVar + +from torchrl._utils import logger as torchrl_logger + +from ..sandbox.base import Sandbox, SandboxError +from .base import Repl, ReplDisplay, ReplError, ReplResult + +_has_jupyter_client = importlib.util.find_spec("jupyter_client") is not None + + +_KERNEL_STARTUP_TIMEOUT = 30.0 + + +class JupyterRepl: + """IPython-kernel-backed REPL with rich outputs. + + Args: + sandbox: The :class:`Sandbox` the kernel runs inside. Today the + kernel binary is launched in the host process; binding it to a + sandbox is on the TODO list (see ``__init__.py``). Treat the + sandbox as advisory until then. + kernel_name: Jupyter kernel spec name (default ``"python3"``). + + Raises: + ImportError: at construction time if ``jupyter_client`` is not + installed. + + Examples: + >>> import asyncio # doctest: +SKIP + >>> from torchrl.envs.llm.agentic.sandbox import UnsafeSubprocessSandbox + >>> from torchrl.envs.llm.agentic.repl import JupyterRepl + >>> async def go(): + ... async with UnsafeSubprocessSandbox() as s: + ... async with JupyterRepl(s) as r: + ... await r.execute("x = 41") + ... return (await r.execute("print(x + 1)")).stdout.strip() + """ + + name: ClassVar[str] = "jupyter" + + def __init__( + self, + sandbox: Sandbox, + *, + kernel_name: str = "python3", + ) -> None: + if not _has_jupyter_client: + raise ImportError( + "JupyterRepl requires jupyter_client. Install with " + "`pip install jupyter_client ipykernel`." 
+ ) + self.sandbox = sandbox + self._kernel_name = kernel_name + self._km: Any = None + self._kc: Any = None + self._exec_count = 0 + + async def open(self) -> None: + if self._km is not None: + return + from jupyter_client.manager import KernelManager + + km = KernelManager(kernel_name=self._kernel_name) + + def _start() -> None: + km.start_kernel() + + await asyncio.get_running_loop().run_in_executor(None, _start) + kc = km.client() + kc.start_channels() + try: + await asyncio.get_running_loop().run_in_executor( + None, lambda: kc.wait_for_ready(timeout=_KERNEL_STARTUP_TIMEOUT) + ) + except RuntimeError as e: + kc.stop_channels() + km.shutdown_kernel(now=True) + raise SandboxError(f"jupyter kernel did not become ready: {e}") from e + self._km = km + self._kc = kc + + async def close(self) -> None: + if self._kc is not None: + try: + self._kc.stop_channels() + except Exception: # pragma: no cover + pass + self._kc = None + if self._km is not None: + try: + self._km.shutdown_kernel(now=True) + except Exception: # pragma: no cover + pass + self._km = None + + async def __aenter__(self) -> JupyterRepl: + await self.open() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + async def execute( + self, code: str, *, timeout: float | None = None + ) -> ReplResult: + if self._kc is None: + raise SandboxError("REPL is not open; call open() first") + msg_id: str = self._kc.execute(code) + loop = asyncio.get_running_loop() + stdout_chunks: list[str] = [] + stderr_chunks: list[str] = [] + displays: list[ReplDisplay] = [] + error: ReplError | None = None + + try: + while True: + msg = await loop.run_in_executor( + None, + lambda: _safe_get_iopub(self._kc, timeout or 1e9), + ) + if msg is None: + return ReplResult(timed_out=True) + parent = msg.get("parent_header") or {} + if parent.get("msg_id") != msg_id: + continue + mtype = msg.get("msg_type") + content = msg.get("content") or {} + if mtype == "stream": + if content.get("name") == "stdout": + stdout_chunks.append(content.get("text", "")) + else: + stderr_chunks.append(content.get("text", "")) + elif mtype in ("execute_result", "display_data"): + data = content.get("data") or {} + for media_type, payload in data.items(): + displays.append( + ReplDisplay(media_type=media_type, data=payload) + ) + elif mtype == "error": + error = ReplError( + ename=str(content.get("ename", "")), + evalue=str(content.get("evalue", "")), + traceback="\n".join(content.get("traceback") or ()), + ) + elif mtype == "status": + if content.get("execution_state") == "idle": + break + except asyncio.TimeoutError: + try: + if self._km is not None: + self._km.interrupt_kernel() + except Exception: # pragma: no cover + pass + return ReplResult(timed_out=True) + self._exec_count += 1 + return ReplResult( + stdout="".join(stdout_chunks), + stderr="".join(stderr_chunks), + display=tuple(displays), + error=error, + timed_out=False, + execution_count=self._exec_count, + ) + + async def interrupt(self) -> None: + if self._km is not None: + try: + self._km.interrupt_kernel() + except Exception: # pragma: no cover + torchrl_logger.warning("jupyter interrupt failed", exc_info=True) + + async def restart(self) -> None: + if self._km is None: + await self.open() + return + try: + self._km.restart_kernel(now=True) + except Exception: # pragma: no cover + await self.close() + await self.open() + self._exec_count = 0 + + +def _safe_get_iopub(kc: Any, timeout: float) -> Any | None: + try: + return kc.get_iopub_msg(timeout=timeout) + except 
queue.Empty: + return None + + +__all__ = ["JupyterRepl"] diff --git a/torchrl/envs/llm/agentic/repl/subprocess.py b/torchrl/envs/llm/agentic/repl/subprocess.py new file mode 100644 index 00000000000..3dbda7a7395 --- /dev/null +++ b/torchrl/envs/llm/agentic/repl/subprocess.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Subprocess-backed REPL. + +Spawns a Python subprocess inside a :class:`Sandbox` and feeds it code via +stdin, reading stdout/stderr after each delimiter. State persists across +:meth:`execute` calls because the subprocess is long-lived. No rich +display protocol -- use :class:`JupyterRepl` for that. + +Implementation note: this is *not* a perfect REPL. Each ``execute`` call +sends the full block + a sentinel print; we read until the sentinel +appears. Errors are returned as a :class:`ReplError` parsed from stderr. +""" +from __future__ import annotations + +import asyncio +import os +import signal +import textwrap +import uuid +from typing import ClassVar + +from ..sandbox.base import Sandbox, SandboxError +from .base import Repl, ReplError, ReplResult + + +_BOOT = textwrap.dedent( + """\ + import sys, traceback + _NS = {} + while True: + try: + line = sys.stdin.readline() + if not line: + break + sentinel, n_lines = line.strip().split(' ', 1) + n_lines = int(n_lines) + code = ''.join(sys.stdin.readline() for _ in range(n_lines)) + try: + exec(compile(code, '', 'exec'), _NS) + err = None + except BaseException: + err = traceback.format_exc() + sys.stdout.write(sentinel + '_END\\n') + sys.stdout.flush() + if err is not None: + sys.stderr.write(err) + sys.stderr.write(sentinel + '_ERR\\n') + sys.stderr.flush() + else: + sys.stderr.write(sentinel + '_OK\\n') + sys.stderr.flush() + except Exception: + traceback.print_exc() + """ +) + + +class SubprocessRepl: + """Persistent Python subprocess used as a REPL. + + Args: + sandbox: The :class:`Sandbox` the subprocess runs inside. Must be + opened separately (or via ``async with``). + python_argv: Argv used to launch the interpreter (default + ``["python3", "-u", "-c", _BOOT]``). + + Examples: + >>> import asyncio # doctest: +SKIP + >>> from torchrl.envs.llm.agentic.sandbox import UnsafeSubprocessSandbox + >>> from torchrl.envs.llm.agentic.repl import SubprocessRepl + >>> async def go(): + ... async with UnsafeSubprocessSandbox() as s: + ... async with SubprocessRepl(s) as r: + ... await r.execute("x = 1") + ... out = await r.execute("print(x)") + ... return out.stdout.strip() + >>> asyncio.run(go()) + """ + + name: ClassVar[str] = "subprocess" + + def __init__( + self, + sandbox: Sandbox, + *, + python_argv: tuple[str, ...] = ("python3", "-u", "-c", _BOOT), + ) -> None: + self.sandbox = sandbox + self._argv = python_argv + self._proc: asyncio.subprocess.Process | None = None + self._lock = asyncio.Lock() + + async def open(self) -> None: + if self._proc is not None and self._proc.returncode is None: + return + # Bypass sandbox.run() for the long-lived process: we need the + # subprocess handle, not just stdout/stderr at the end. We honor + # sandbox lifecycle but spawn the process directly inside it. + # For UnsafeSubprocessSandbox this is plain create_subprocess_exec; + # for hardened backends, we ask the sandbox to wrap the argv. 
+ argv = await _wrap_argv_via_sandbox(self.sandbox, self._argv) + try: + self._proc = await asyncio.create_subprocess_exec( + *argv, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + except FileNotFoundError as e: + raise SandboxError(f"could not start REPL: {e}") from e + + async def close(self) -> None: + if self._proc is None: + return + try: + if self._proc.returncode is None: + self._proc.kill() + try: + await asyncio.wait_for(self._proc.wait(), timeout=2.0) + except asyncio.TimeoutError: # pragma: no cover + pass + finally: + self._proc = None + + async def __aenter__(self) -> SubprocessRepl: + await self.open() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + async def execute( + self, code: str, *, timeout: float | None = None + ) -> ReplResult: + if self._proc is None or self._proc.returncode is not None: + raise SandboxError("REPL is not running; call open() first") + async with self._lock: + sentinel = "S" + uuid.uuid4().hex + n_lines = code.count("\n") + 1 + payload = f"{sentinel} {n_lines}\n{code}" + if not payload.endswith("\n"): + payload += "\n" + assert self._proc.stdin is not None + self._proc.stdin.write(payload.encode("utf-8")) + try: + await self._proc.stdin.drain() + except (BrokenPipeError, ConnectionResetError) as e: + raise SandboxError(f"REPL stdin closed: {e}") from e + try: + stdout, stderr = await asyncio.wait_for( + self._read_until_sentinels(sentinel), timeout=timeout + ) + timed_out = False + err: ReplError | None = None + if stderr.endswith(f"{sentinel}_ERR\n"): + body = stderr[: -len(f"{sentinel}_ERR\n")] + err = _parse_traceback(body) + else: + # Strip the OK marker. + if stderr.endswith(f"{sentinel}_OK\n"): + stderr = stderr[: -len(f"{sentinel}_OK\n")] + stdout_clean = stdout + if stdout_clean.endswith(f"{sentinel}_END\n"): + stdout_clean = stdout_clean[: -len(f"{sentinel}_END\n")] + return ReplResult( + stdout=stdout_clean, + stderr=stderr, + error=err, + timed_out=False, + execution_count=-1, + ) + except asyncio.TimeoutError: + # Send SIGINT and let the boot loop recover. State is + # preserved unless the user code is in an uninterruptible + # syscall, in which case the user must call restart(). + try: + if self._proc.pid is not None: + os.kill(self._proc.pid, signal.SIGINT) + except ProcessLookupError: # pragma: no cover + pass + return ReplResult(stdout="", stderr="", timed_out=True) + + async def interrupt(self) -> None: + if self._proc is None or self._proc.returncode is not None: + return + try: + if self._proc.pid is not None: + os.kill(self._proc.pid, signal.SIGINT) + except ProcessLookupError: # pragma: no cover + pass + + async def restart(self) -> None: + await self.close() + await self.open() + + async def _read_until_sentinels( + self, sentinel: str + ) -> tuple[str, str]: + # Read stdout until "_END\n" appears, then drain stderr + # until "_OK\n" or "_ERR\n" appears. 
+ assert self._proc is not None + out_buf: list[bytes] = [] + end = f"{sentinel}_END\n".encode() + assert self._proc.stdout is not None + while True: + chunk = await self._proc.stdout.readline() + if not chunk: + break + out_buf.append(chunk) + if chunk == end: + break + err_buf: list[bytes] = [] + ok = f"{sentinel}_OK\n".encode() + e_err = f"{sentinel}_ERR\n".encode() + assert self._proc.stderr is not None + while True: + chunk = await self._proc.stderr.readline() + if not chunk: + break + err_buf.append(chunk) + if chunk == ok or chunk == e_err: + break + return ( + b"".join(out_buf).decode("utf-8", errors="replace"), + b"".join(err_buf).decode("utf-8", errors="replace"), + ) + + +def _parse_traceback(tb: str) -> ReplError: + """Parse the last line of a traceback into ``ename: evalue``.""" + lines = [line for line in tb.splitlines() if line] + if not lines: + return ReplError(ename="Error", evalue="", traceback=tb) + last = lines[-1] + if ":" in last: + ename, _, evalue = last.partition(":") + return ReplError(ename=ename.strip(), evalue=evalue.strip(), traceback=tb) + return ReplError(ename=last.strip(), evalue="", traceback=tb) + + +async def _wrap_argv_via_sandbox(sandbox: Sandbox, argv: tuple[str, ...]) -> list[str]: + """Best-effort: ask the sandbox to compute the prefixed argv, fallback + to the raw argv if the backend doesn't support pre-wrapping. + """ + builder = getattr(sandbox, "_build_argv", None) + if callable(builder): + try: + return list(builder(list(argv), sandbox.limits, None)) + except TypeError: + try: + return list(builder(list(argv), sandbox.limits)) + except Exception: + return list(argv) + except Exception: + return list(argv) + return list(argv) + + +__all__ = ["SubprocessRepl"] diff --git a/torchrl/envs/llm/agentic/sandbox/__init__.py b/torchrl/envs/llm/agentic/sandbox/__init__.py new file mode 100644 index 00000000000..ffef10b1fa2 --- /dev/null +++ b/torchrl/envs/llm/agentic/sandbox/__init__.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Sandbox backends for the agentic toolkit. + +The default :func:`default_sandbox` picks bubblewrap on Linux, +sandbox-exec on macOS, and falls back to :class:`UnsafeSubprocessSandbox` +elsewhere (with a :class:`UserWarning`). +""" +from __future__ import annotations + +import shutil +import sys +import warnings + +from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult +from .docker import DockerSandbox +from .e2b import E2BSandbox +from .modal import ModalSandbox +from .subprocess_bwrap import BubblewrapSandbox, _has_bwrap +from .subprocess_seatbelt import SeatbeltSandbox, _has_sandbox_exec +from .unsafe import UnsafeSubprocessSandbox + + +def default_sandbox(limits: ResourceLimits | None = None) -> Sandbox: + """Return the best available sandbox for the current platform. + + - Linux with ``bwrap`` on PATH -> :class:`BubblewrapSandbox`. + - macOS with ``sandbox-exec`` on PATH -> :class:`SeatbeltSandbox`. + - Otherwise -> :class:`UnsafeSubprocessSandbox` with a warning. + """ + if sys.platform.startswith("linux") and _has_bwrap: + return BubblewrapSandbox(limits=limits) + if sys.platform == "darwin" and _has_sandbox_exec: + return SeatbeltSandbox(limits=limits) + warnings.warn( + "No hardened sandbox backend is available on this platform " + f"({sys.platform!r}). 
Falling back to UnsafeSubprocessSandbox; " + "this is fine for tests but NOT for running untrusted model " + "output.", + UserWarning, + stacklevel=2, + ) + return UnsafeSubprocessSandbox(limits=limits) + + +__all__ = [ + "BubblewrapSandbox", + "DockerSandbox", + "E2BSandbox", + "ModalSandbox", + "ResourceLimits", + "Sandbox", + "SandboxError", + "SandboxResult", + "SeatbeltSandbox", + "UnsafeSubprocessSandbox", + "default_sandbox", +] diff --git a/torchrl/envs/llm/agentic/sandbox/base.py b/torchrl/envs/llm/agentic/sandbox/base.py new file mode 100644 index 00000000000..c35e3d0aed9 --- /dev/null +++ b/torchrl/envs/llm/agentic/sandbox/base.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Sandbox protocol and value types. + +A :class:`Sandbox` is an async context manager owning an isolated execution +environment. :meth:`Sandbox.run` launches a subprocess inside it, +:meth:`Sandbox.write_file` and :meth:`Sandbox.read_file` mediate I/O. The +default backends -- :class:`BubblewrapSandbox` (Linux) and +:class:`SeatbeltSandbox` (macOS) -- enforce filesystem and network isolation +via OS-bundled tools. + +For environments where neither is available, +:class:`UnsafeSubprocessSandbox` provides a no-op fallback that runs a bare +subprocess with no isolation. It emits a ``UserWarning`` on every +:meth:`open` call so the lack of containment is impossible to miss. +""" +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Any, ClassVar, Literal, Protocol, runtime_checkable + + +_NetworkPolicy = Literal["none", "loopback", "allowlist", "full"] + + +class SandboxError(RuntimeError): + """Raised on sandbox infrastructure failures (launch, kernel error, etc.). + + Tool processes that exit non-zero do *not* raise; the non-zero status is + surfaced via :attr:`SandboxResult.exit_code`. + """ + + +@dataclass(frozen=True, slots=True) +class ResourceLimits: + """Per-sandbox or per-call resource limits. + + Attributes: + cpu_seconds: Soft CPU budget. ``None`` means unlimited. + wall_seconds: Wall-clock timeout. ``None`` means unlimited. + memory_bytes: Address-space cap. ``None`` means unlimited. + network: Policy for outbound network. ``"none"`` blocks all sockets, + ``"loopback"`` allows 127.0.0.0/8 only, ``"allowlist"`` consults + :attr:`network_allowlist`, ``"full"`` is unrestricted. + network_allowlist: ``host:port`` strings, used only when + ``network == "allowlist"``. + fs_read_roots: Absolute paths the sandbox may read from. Empty means + backend default (typically ``/`` read-only on Linux/macOS). + fs_write_roots: Absolute paths the sandbox may write to. Empty means + no writes allowed. + max_processes: Cap on concurrent subprocesses. ``None`` for unlimited. + env: Environment-variable allowlist. ``None`` means a clean env with + only ``PATH``, ``HOME``, ``LANG``. + """ + + cpu_seconds: float | None = 30.0 + wall_seconds: float | None = 60.0 + memory_bytes: int | None = 512 * 1024 * 1024 + network: _NetworkPolicy = "none" + network_allowlist: tuple[str, ...] = () + fs_read_roots: tuple[str, ...] = () + fs_write_roots: tuple[str, ...] 
= () + max_processes: int | None = 32 + env: Mapping[str, str] | None = None + + def narrow(self, other: ResourceLimits | None) -> ResourceLimits: + """Return a new :class:`ResourceLimits` that is at most as permissive + as both ``self`` and ``other``. Used by :meth:`Sandbox.run` to apply a + per-call override that may only narrow the construction limits. + """ + if other is None: + return self + + def _min_or(a: float | None, b: float | None) -> float | None: + if a is None: + return b + if b is None: + return a + return min(a, b) + + # Tighten network: choose the strictest. + rank: dict[_NetworkPolicy, int] = { + "none": 0, + "loopback": 1, + "allowlist": 2, + "full": 3, + } + net = self.network if rank[self.network] <= rank[other.network] else other.network + return ResourceLimits( + cpu_seconds=_min_or(self.cpu_seconds, other.cpu_seconds), + wall_seconds=_min_or(self.wall_seconds, other.wall_seconds), + memory_bytes=_min_or(self.memory_bytes, other.memory_bytes), + network=net, + network_allowlist=( + tuple(set(self.network_allowlist) & set(other.network_allowlist)) + if self.network_allowlist or other.network_allowlist + else () + ), + fs_read_roots=tuple( + sorted(set(self.fs_read_roots) & set(other.fs_read_roots)) + ) + if self.fs_read_roots and other.fs_read_roots + else (self.fs_read_roots or other.fs_read_roots), + fs_write_roots=tuple( + sorted(set(self.fs_write_roots) & set(other.fs_write_roots)) + ) + if self.fs_write_roots and other.fs_write_roots + else (self.fs_write_roots or other.fs_write_roots), + max_processes=_min_or(self.max_processes, other.max_processes), + env=other.env if other.env is not None else self.env, + ) + + +@dataclass(frozen=True, slots=True) +class SandboxResult: + """Outcome of a single :meth:`Sandbox.run` invocation. + + Attributes: + stdout: Captured standard output (may be truncated). + stderr: Captured standard error (may be truncated). + exit_code: Subprocess exit status. Negative on signal. + wall_seconds: Observed wall-clock duration. + timed_out: ``True`` if the subprocess hit + :attr:`ResourceLimits.wall_seconds` before exiting. + truncated: ``True`` if stdout/stderr were truncated by an output cap. + artifacts: File contents emitted under + :attr:`ResourceLimits.fs_write_roots`, keyed by relative path. + Populated lazily by backends that support it; default empty. + """ + + stdout: str + stderr: str + exit_code: int + wall_seconds: float + timed_out: bool = False + truncated: bool = False + artifacts: Mapping[str, bytes] = field(default_factory=dict) + + +@runtime_checkable +class Sandbox(Protocol): + """An async context manager owning an isolated execution environment. + + Lifecycle: ``open()`` is idempotent and required before ``run()``; + ``close()`` releases all OS resources. Use as ``async with sandbox:`` to + bracket lifecycle automatically. + + :meth:`run` does *not* raise on tool exit codes. It raises + :class:`SandboxError` only on infrastructure failures (sandbox launch, + host kernel error). Per-call ``limits`` may only narrow construction + ``limits``; widening attempts are silently clamped. + + All paths in :meth:`write_file` / :meth:`read_file` are sandbox-virtual; + the backend is responsible for translating to host paths. + """ + + name: ClassVar[str] + limits: ResourceLimits + + async def open(self) -> None: ... + + async def close(self) -> None: ... + + async def __aenter__(self) -> Sandbox: ... + + async def __aexit__(self, exc_type, exc, tb) -> None: ... 
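+
+    # Narrowing sketch (illustrative values): a sandbox built with
+    #   ResourceLimits(wall_seconds=60, network="full")
+    # given a per-call override of
+    #   ResourceLimits(wall_seconds=5, network="none")
+    # runs with wall_seconds=5 and network="none"; an attempt to widen (say
+    # wall_seconds=600) is clamped back to 60 by ResourceLimits.narrow().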
+ + async def run( + self, + argv: Sequence[str], + *, + stdin: bytes | None = None, + cwd: str | None = None, + limits: ResourceLimits | None = None, + ) -> SandboxResult: ... + + async def write_file(self, path: str, data: bytes) -> None: ... + + async def read_file( + self, path: str, max_bytes: int | None = None + ) -> bytes: ... + + +__all__ = [ + "ResourceLimits", + "Sandbox", + "SandboxError", + "SandboxResult", +] diff --git a/torchrl/envs/llm/agentic/sandbox/docker.py b/torchrl/envs/llm/agentic/sandbox/docker.py new file mode 100644 index 00000000000..c7b025ed318 --- /dev/null +++ b/torchrl/envs/llm/agentic/sandbox/docker.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Docker / Podman sandbox backend (stub). + +Tracked in the agentic ``__init__.py`` TODO list as +"E2B / Modal / Docker real implementations." This file exists so the +import surface is stable from day one and downstream code can reference +``DockerSandbox`` even before the implementation lands. +""" +from __future__ import annotations + +from collections.abc import Sequence +from typing import ClassVar + +from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult + + +class DockerSandbox: + """Container-based sandbox (not yet implemented).""" + + name: ClassVar[str] = "docker" + + def __init__( + self, + limits: ResourceLimits | None = None, + *, + image: str = "python:3.11-slim", + ) -> None: + self.limits = limits or ResourceLimits() + self.image = image + + async def open(self) -> None: + raise NotImplementedError( + "DockerSandbox is not yet implemented. See the TODO list in " + "torchrl/envs/llm/agentic/__init__.py and contribute! For now " + "use BubblewrapSandbox (Linux) or SeatbeltSandbox (macOS)." + ) + + async def close(self) -> None: + pass + + async def __aenter__(self) -> DockerSandbox: # pragma: no cover + await self.open() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: # pragma: no cover + await self.close() + + async def run( + self, + argv: Sequence[str], + *, + stdin: bytes | None = None, + cwd: str | None = None, + limits: ResourceLimits | None = None, + ) -> SandboxResult: # pragma: no cover + raise NotImplementedError + + async def write_file(self, path: str, data: bytes) -> None: # pragma: no cover + raise NotImplementedError + + async def read_file( + self, path: str, max_bytes: int | None = None + ) -> bytes: # pragma: no cover + raise NotImplementedError + + +__all__ = ["DockerSandbox"] diff --git a/torchrl/envs/llm/agentic/sandbox/e2b.py b/torchrl/envs/llm/agentic/sandbox/e2b.py new file mode 100644 index 00000000000..7251d72c523 --- /dev/null +++ b/torchrl/envs/llm/agentic/sandbox/e2b.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""E2B hosted-sandbox backend (stub). + +Tracked in the agentic ``__init__.py`` TODO list. Stub kept so the import +surface is stable. 
+""" +from __future__ import annotations + +import importlib.util +from collections.abc import Sequence +from typing import ClassVar + +from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult + +_has_e2b = importlib.util.find_spec("e2b") is not None + + +class E2BSandbox: + """E2B-hosted sandbox (not yet implemented).""" + + name: ClassVar[str] = "e2b" + + def __init__(self, limits: ResourceLimits | None = None) -> None: + self.limits = limits or ResourceLimits() + + async def open(self) -> None: + raise NotImplementedError( + "E2BSandbox is not yet implemented. See the TODO list in " + "torchrl/envs/llm/agentic/__init__.py." + ) + + async def close(self) -> None: + pass + + async def __aenter__(self) -> E2BSandbox: # pragma: no cover + await self.open() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: # pragma: no cover + await self.close() + + async def run( + self, + argv: Sequence[str], + *, + stdin: bytes | None = None, + cwd: str | None = None, + limits: ResourceLimits | None = None, + ) -> SandboxResult: # pragma: no cover + raise NotImplementedError + + async def write_file(self, path: str, data: bytes) -> None: # pragma: no cover + raise NotImplementedError + + async def read_file( + self, path: str, max_bytes: int | None = None + ) -> bytes: # pragma: no cover + raise NotImplementedError + + +__all__ = ["E2BSandbox"] diff --git a/torchrl/envs/llm/agentic/sandbox/modal.py b/torchrl/envs/llm/agentic/sandbox/modal.py new file mode 100644 index 00000000000..8dc9fc7435e --- /dev/null +++ b/torchrl/envs/llm/agentic/sandbox/modal.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Modal hosted-sandbox backend (stub). + +Tracked in the agentic ``__init__.py`` TODO list. Stub kept so the import +surface is stable. +""" +from __future__ import annotations + +import importlib.util +from collections.abc import Sequence +from typing import ClassVar + +from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult + +_has_modal = importlib.util.find_spec("modal") is not None + + +class ModalSandbox: + """Modal-hosted sandbox (not yet implemented).""" + + name: ClassVar[str] = "modal" + + def __init__(self, limits: ResourceLimits | None = None) -> None: + self.limits = limits or ResourceLimits() + + async def open(self) -> None: + raise NotImplementedError( + "ModalSandbox is not yet implemented. See the TODO list in " + "torchrl/envs/llm/agentic/__init__.py." 
+ ) + + async def close(self) -> None: + pass + + async def __aenter__(self) -> ModalSandbox: # pragma: no cover + await self.open() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: # pragma: no cover + await self.close() + + async def run( + self, + argv: Sequence[str], + *, + stdin: bytes | None = None, + cwd: str | None = None, + limits: ResourceLimits | None = None, + ) -> SandboxResult: # pragma: no cover + raise NotImplementedError + + async def write_file(self, path: str, data: bytes) -> None: # pragma: no cover + raise NotImplementedError + + async def read_file( + self, path: str, max_bytes: int | None = None + ) -> bytes: # pragma: no cover + raise NotImplementedError + + +__all__ = ["ModalSandbox"] diff --git a/torchrl/envs/llm/agentic/sandbox/subprocess_bwrap.py b/torchrl/envs/llm/agentic/sandbox/subprocess_bwrap.py new file mode 100644 index 00000000000..2b3014abbc8 --- /dev/null +++ b/torchrl/envs/llm/agentic/sandbox/subprocess_bwrap.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Bubblewrap-backed sandbox (Linux default). + +Builds a ``bwrap`` argv prefix from the :class:`ResourceLimits` and runs the +target command inside the resulting unprivileged user namespace. This gives +us: + +- a private mount namespace (write_roots are bind-mounted RW, the rest is RO) +- a private network namespace by default (``--unshare-net``); ``"allowlist"`` + and ``"full"`` keep the host network namespace and rely on the caller to + ensure connections only succeed where allowed. +- a private PID namespace (``--unshare-pid``) +- ``--die-with-parent`` so the sandbox dies with the parent process. + +The implementation is best-effort: bubblewrap's API is large, and edge +cases (rootless overlays, nested user namespaces in some kernels) are +documented but not silently papered over. +""" +from __future__ import annotations + +import asyncio +import importlib.util +import os +import shutil +import time +from collections.abc import Sequence +from pathlib import Path +from typing import ClassVar + +from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult + +_OUTPUT_CAP = 1 << 20 + +_has_bwrap = shutil.which("bwrap") is not None + + +class BubblewrapSandbox: + """Linux sandbox backed by ``bwrap`` (bubblewrap). + + Args: + limits: Construction-time resource limits. + bwrap_path: Override the ``bwrap`` executable path. Default uses + :func:`shutil.which`. + + Raises: + SandboxError: at :meth:`open` time if ``bwrap`` is not on ``PATH`` + and ``bwrap_path`` was not supplied. + + Example: + >>> import asyncio # doctest: +SKIP + >>> from torchrl.envs.llm.agentic.sandbox import ( + ... BubblewrapSandbox, ResourceLimits, + ... ) + >>> async def go(): + ... async with BubblewrapSandbox( + ... limits=ResourceLimits(network="none", + ... fs_write_roots=("/tmp/work",)) + ... ) as s: + ... r = await s.run(["python3", "-c", "print('hi')"]) + ... return r.stdout.strip() + """ + + name: ClassVar[str] = "bubblewrap" + + def __init__( + self, + limits: ResourceLimits | None = None, + *, + bwrap_path: str | None = None, + ) -> None: + self.limits = limits or ResourceLimits() + self._bwrap = bwrap_path or shutil.which("bwrap") + self._opened = False + + async def open(self) -> None: + if self._opened: + return + if not self._bwrap: + raise SandboxError( + "bwrap not found on PATH. 
Install bubblewrap " + "(apt-get install bubblewrap / dnf install bubblewrap) or " + "use UnsafeSubprocessSandbox for testing." + ) + self._opened = True + + async def close(self) -> None: + self._opened = False + + async def __aenter__(self) -> BubblewrapSandbox: + await self.open() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + def _build_argv( + self, argv: Sequence[str], limits: ResourceLimits, cwd: str | None + ) -> list[str]: + bw: list[str] = [self._bwrap or "bwrap"] + bw += ["--die-with-parent", "--unshare-user", "--unshare-pid", "--unshare-ipc"] + if limits.network in ("none", "loopback"): + bw += ["--unshare-net"] + bw += ["--proc", "/proc", "--dev", "/dev"] + # Read-only bind of host root (cheap; lets the launched binary find + # its own libs). Each write root is then bind-mounted RW on top. + bw += ["--ro-bind", "/", "/"] + for root in limits.fs_write_roots: + Path(root).mkdir(parents=True, exist_ok=True) + bw += ["--bind", root, root] + if cwd: + bw += ["--chdir", cwd] + # Clear env, then re-add only what we want. + bw += ["--clearenv"] + env = limits.env or { + "PATH": os.environ.get("PATH", "/usr/bin:/bin"), + "HOME": "/tmp", + "LANG": os.environ.get("LANG", "C.UTF-8"), + } + for k, v in env.items(): + bw += ["--setenv", k, v] + # prlimit guards memory; CPU seconds we leave to wall_seconds + ulimit + # via shell-out only when memory_bytes is set. + if limits.memory_bytes is not None and shutil.which("prlimit"): + bw += [ + "prlimit", + f"--as={limits.memory_bytes}", + "--", + ] + bw += list(argv) + return bw + + def _build_env(self) -> dict[str, str]: + # bwrap sees the parent env only for argv expansion; --clearenv + + # --setenv handle the child env. Keep the parent-side env minimal. + return { + "PATH": os.environ.get("PATH", "/usr/bin:/bin"), + } + + async def run( + self, + argv: Sequence[str], + *, + stdin: bytes | None = None, + cwd: str | None = None, + limits: ResourceLimits | None = None, + ) -> SandboxResult: + if not self._opened: + raise SandboxError("sandbox is not open; call open() first") + eff = self.limits.narrow(limits) + bw_argv = self._build_argv(argv, eff, cwd) + t0 = time.monotonic() + try: + proc = await asyncio.create_subprocess_exec( + *bw_argv, + stdin=asyncio.subprocess.PIPE if stdin is not None else None, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=self._build_env(), + ) + except FileNotFoundError as e: + raise SandboxError(f"could not launch bwrap: {e}") from e + try: + out_b, err_b = await asyncio.wait_for( + proc.communicate(stdin), timeout=eff.wall_seconds + ) + timed_out = False + except asyncio.TimeoutError: + proc.kill() + try: + out_b, err_b = await proc.communicate() + except Exception: # pragma: no cover + out_b, err_b = b"", b"" + timed_out = True + wall = time.monotonic() - t0 + truncated = len(out_b) > _OUTPUT_CAP or len(err_b) > _OUTPUT_CAP + return SandboxResult( + stdout=out_b[:_OUTPUT_CAP].decode("utf-8", errors="replace"), + stderr=err_b[:_OUTPUT_CAP].decode("utf-8", errors="replace"), + exit_code=proc.returncode if proc.returncode is not None else -1, + wall_seconds=wall, + timed_out=timed_out, + truncated=truncated, + ) + + async def write_file(self, path: str, data: bytes) -> None: + if not self._opened: + raise SandboxError("sandbox is not open; call open() first") + # Writes happen on the host side at a path that will be bind-mounted + # RW into the sandbox. Verify the target lies under a write root. 
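+        # Note: os.path.commonpath() needs comparable (absolute) paths, which
+        # matches the ResourceLimits contract for fs_write_roots; with an
+        # empty fs_write_roots this check rejects every write.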
+ if not any( + os.path.commonpath([path, r]) == r for r in self.limits.fs_write_roots + ): + raise SandboxError( + f"refusing to write to {path!r}: outside fs_write_roots " + f"{self.limits.fs_write_roots!r}" + ) + Path(path).parent.mkdir(parents=True, exist_ok=True) + Path(path).write_bytes(data) + + async def read_file( + self, path: str, max_bytes: int | None = None + ) -> bytes: + if not self._opened: + raise SandboxError("sandbox is not open; call open() first") + b = Path(path).read_bytes() + if max_bytes is not None and len(b) > max_bytes: + return b[:max_bytes] + return b + + +__all__ = ["BubblewrapSandbox"] diff --git a/torchrl/envs/llm/agentic/sandbox/subprocess_seatbelt.py b/torchrl/envs/llm/agentic/sandbox/subprocess_seatbelt.py new file mode 100644 index 00000000000..62a66c7e6d4 --- /dev/null +++ b/torchrl/envs/llm/agentic/sandbox/subprocess_seatbelt.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""macOS sandbox-exec backend (seatbelt). + +Generates a small Scheme profile from the :class:`ResourceLimits` and runs +the target command via ``sandbox-exec -p --``. Matches the +isolation guarantees of bubblewrap to the extent macOS allows: filesystem +read/write restrictions and full network deny. + +.. note:: + Apple has officially deprecated ``sandbox-exec``, but the binary still + ships with macOS 14+ and works for our purposes. Where stronger + guarantees are needed (or for portability across CI platforms) prefer a + container backend (Docker stub today). +""" +from __future__ import annotations + +import asyncio +import os +import shutil +import time +from collections.abc import Sequence +from pathlib import Path +from typing import ClassVar + +from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult + +_OUTPUT_CAP = 1 << 20 + +_has_sandbox_exec = shutil.which("sandbox-exec") is not None + + +def _profile(limits: ResourceLimits) -> str: + """Build a sandbox-exec Scheme profile from ``limits``.""" + lines: list[str] = [ + "(version 1)", + "(deny default)", + "(allow process-fork)", + "(allow process-exec)", + "(allow signal (target self))", + "(allow sysctl-read)", + "(allow file-read*)", # readable host root by default + "(allow mach-lookup)", + "(allow ipc-posix-shm)", + ] + if limits.network in ("none", "loopback"): + lines.append("(deny network*)") + else: + lines.append("(allow network*)") + if limits.fs_write_roots: + # Allow writes only under the named roots. + for root in limits.fs_write_roots: + lines.append(f'(allow file-write* (subpath "{root}"))') + # /private/var, /tmp need write for many runtimes; allow only if user + # explicitly listed them. + return "\n".join(lines) + + +class SeatbeltSandbox: + """macOS sandbox backed by ``sandbox-exec``. + + Args: + limits: Construction-time resource limits. + + Raises: + SandboxError: at :meth:`open` if ``sandbox-exec`` is not available. + """ + + name: ClassVar[str] = "seatbelt" + + def __init__(self, limits: ResourceLimits | None = None) -> None: + self.limits = limits or ResourceLimits() + self._exec = shutil.which("sandbox-exec") + self._opened = False + + async def open(self) -> None: + if self._opened: + return + if not self._exec: + raise SandboxError( + "sandbox-exec not found. SeatbeltSandbox requires macOS." 
+ ) + self._opened = True + + async def close(self) -> None: + self._opened = False + + async def __aenter__(self) -> SeatbeltSandbox: + await self.open() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + def _build_argv( + self, + argv: Sequence[str], + limits: ResourceLimits, + ) -> list[str]: + return [self._exec or "sandbox-exec", "-p", _profile(limits), "--", *argv] + + def _build_env(self, limits: ResourceLimits) -> dict[str, str]: + if limits.env is None: + return { + "PATH": os.environ.get("PATH", "/usr/bin:/bin"), + "HOME": os.environ.get("HOME", "/tmp"), + "LANG": os.environ.get("LANG", "C.UTF-8"), + } + return dict(limits.env) + + async def run( + self, + argv: Sequence[str], + *, + stdin: bytes | None = None, + cwd: str | None = None, + limits: ResourceLimits | None = None, + ) -> SandboxResult: + if not self._opened: + raise SandboxError("sandbox is not open; call open() first") + eff = self.limits.narrow(limits) + sb_argv = self._build_argv(argv, eff) + t0 = time.monotonic() + try: + proc = await asyncio.create_subprocess_exec( + *sb_argv, + stdin=asyncio.subprocess.PIPE if stdin is not None else None, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=self._build_env(eff), + ) + except FileNotFoundError as e: + raise SandboxError(f"could not launch sandbox-exec: {e}") from e + try: + out_b, err_b = await asyncio.wait_for( + proc.communicate(stdin), timeout=eff.wall_seconds + ) + timed_out = False + except asyncio.TimeoutError: + proc.kill() + try: + out_b, err_b = await proc.communicate() + except Exception: # pragma: no cover + out_b, err_b = b"", b"" + timed_out = True + wall = time.monotonic() - t0 + truncated = len(out_b) > _OUTPUT_CAP or len(err_b) > _OUTPUT_CAP + return SandboxResult( + stdout=out_b[:_OUTPUT_CAP].decode("utf-8", errors="replace"), + stderr=err_b[:_OUTPUT_CAP].decode("utf-8", errors="replace"), + exit_code=proc.returncode if proc.returncode is not None else -1, + wall_seconds=wall, + timed_out=timed_out, + truncated=truncated, + ) + + async def write_file(self, path: str, data: bytes) -> None: + if not self._opened: + raise SandboxError("sandbox is not open; call open() first") + if not any( + os.path.commonpath([path, r]) == r for r in self.limits.fs_write_roots + ): + raise SandboxError( + f"refusing to write to {path!r}: outside fs_write_roots " + f"{self.limits.fs_write_roots!r}" + ) + Path(path).parent.mkdir(parents=True, exist_ok=True) + Path(path).write_bytes(data) + + async def read_file( + self, path: str, max_bytes: int | None = None + ) -> bytes: + if not self._opened: + raise SandboxError("sandbox is not open; call open() first") + b = Path(path).read_bytes() + if max_bytes is not None and len(b) > max_bytes: + return b[:max_bytes] + return b + + +__all__ = ["SeatbeltSandbox"] diff --git a/torchrl/envs/llm/agentic/sandbox/unsafe.py b/torchrl/envs/llm/agentic/sandbox/unsafe.py new file mode 100644 index 00000000000..883111c9ba6 --- /dev/null +++ b/torchrl/envs/llm/agentic/sandbox/unsafe.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Unsandboxed subprocess backend (testing / fallback only). + +Runs argv directly via :func:`asyncio.create_subprocess_exec` with no +isolation. Emits ``UserWarning`` on every :meth:`open` call so the lack of +containment is impossible to miss in a real deployment. 
+""" +from __future__ import annotations + +import asyncio +import os +import shutil +import time +import warnings +from collections.abc import Sequence +from pathlib import Path +from typing import ClassVar + +from torchrl._utils import logger as torchrl_logger + +from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult + +_OUTPUT_CAP = 1 << 20 # 1 MiB per stream + + +class UnsafeSubprocessSandbox: + """Bare ``asyncio.create_subprocess_exec`` with no isolation. + + Useful for unit tests and environments where neither bubblewrap nor + sandbox-exec is available. **Not a security boundary.** Emits a + :class:`UserWarning` on every :meth:`open` so this is loud. + + The ``limits.fs_write_roots`` and ``limits.network`` policies are *not + enforced* by this backend; pass them anyway so tests can switch to + :class:`BubblewrapSandbox` or :class:`SeatbeltSandbox` without code + changes. + + Examples: + >>> import asyncio + >>> async def go(): + ... async with UnsafeSubprocessSandbox() as s: + ... r = await s.run(["echo", "hi"]) + ... return r.stdout.strip() + >>> asyncio.run(go()) # doctest: +SKIP + 'hi' + """ + + name: ClassVar[str] = "unsafe-subprocess" + + def __init__(self, limits: ResourceLimits | None = None) -> None: + self.limits = limits or ResourceLimits() + self._opened = False + + async def open(self) -> None: + if self._opened: + return + warnings.warn( + "UnsafeSubprocessSandbox provides NO isolation. Do not use it " + "with untrusted model output in production. Switch to " + "BubblewrapSandbox (Linux) or SeatbeltSandbox (macOS).", + UserWarning, + stacklevel=2, + ) + self._opened = True + + async def close(self) -> None: + self._opened = False + + async def __aenter__(self) -> UnsafeSubprocessSandbox: + await self.open() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + def _build_env(self, limits: ResourceLimits) -> dict[str, str]: + if limits.env is None: + base = { + "PATH": os.environ.get("PATH", "/usr/bin:/bin"), + "HOME": os.environ.get("HOME", "/tmp"), + "LANG": os.environ.get("LANG", "C.UTF-8"), + } + return base + return dict(limits.env) + + async def run( + self, + argv: Sequence[str], + *, + stdin: bytes | None = None, + cwd: str | None = None, + limits: ResourceLimits | None = None, + ) -> SandboxResult: + if not self._opened: + raise SandboxError("sandbox is not open; call open() first") + eff = self.limits.narrow(limits) + env = self._build_env(eff) + t0 = time.monotonic() + try: + proc = await asyncio.create_subprocess_exec( + *argv, + stdin=asyncio.subprocess.PIPE if stdin is not None else None, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=env, + ) + except FileNotFoundError as e: + raise SandboxError(f"could not launch subprocess: {e}") from e + try: + out_b, err_b = await asyncio.wait_for( + proc.communicate(stdin), + timeout=eff.wall_seconds, + ) + timed_out = False + except asyncio.TimeoutError: + proc.kill() + try: + out_b, err_b = await proc.communicate() + except Exception: # pragma: no cover -- defensive + out_b, err_b = b"", b"" + timed_out = True + wall = time.monotonic() - t0 + truncated = len(out_b) > _OUTPUT_CAP or len(err_b) > _OUTPUT_CAP + if truncated: + torchrl_logger.warning( + "UnsafeSubprocessSandbox truncated subprocess output (cap=%d)", + _OUTPUT_CAP, + ) + return SandboxResult( + stdout=out_b[:_OUTPUT_CAP].decode("utf-8", errors="replace"), + stderr=err_b[:_OUTPUT_CAP].decode("utf-8", errors="replace"), + exit_code=proc.returncode if 
proc.returncode is not None else -1, + wall_seconds=wall, + timed_out=timed_out, + truncated=truncated, + ) + + async def write_file(self, path: str, data: bytes) -> None: + if not self._opened: + raise SandboxError("sandbox is not open; call open() first") + Path(path).parent.mkdir(parents=True, exist_ok=True) + Path(path).write_bytes(data) + + async def read_file( + self, path: str, max_bytes: int | None = None + ) -> bytes: + if not self._opened: + raise SandboxError("sandbox is not open; call open() first") + b = Path(path).read_bytes() + if max_bytes is not None and len(b) > max_bytes: + return b[:max_bytes] + return b + + +def _shutil_which(name: str) -> str | None: + return shutil.which(name) + + +__all__ = ["UnsafeSubprocessSandbox"] diff --git a/torchrl/envs/llm/agentic/schema.py b/torchrl/envs/llm/agentic/schema.py new file mode 100644 index 00000000000..1a92c3f71f8 --- /dev/null +++ b/torchrl/envs/llm/agentic/schema.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""JSON Schema helpers for :class:`~torchrl.envs.llm.agentic.Tool`. + +Tools declare ``input_schema`` as a plain JSON Schema dict (matches OpenAI, +Anthropic, and MCP). A small ``validate_args`` helper enforces required +fields and primitive types without pulling in a full JSON Schema validator. +For users who prefer pydantic, :func:`json_schema_from_pydantic` converts a +``BaseModel`` subclass to the equivalent dict. +""" +from __future__ import annotations + +import importlib.util +from collections.abc import Mapping +from typing import Any + +_has_pydantic = importlib.util.find_spec("pydantic") is not None + + +_TYPE_MAP: dict[str, type | tuple[type, ...]] = { + "string": str, + "integer": int, + "number": (int, float), + "boolean": bool, + "array": list, + "object": dict, + "null": type(None), +} + + +class SchemaValidationError(ValueError): + """Raised by :func:`validate_args` on a schema mismatch.""" + + +def validate_args( + args: Mapping[str, Any], schema: Mapping[str, Any] | None +) -> None: + """Validate ``args`` against a JSON Schema dict. + + Implements the subset that matters for tool-call dispatch: + + - top-level ``type: object``, + - ``required`` field presence, + - per-property ``type`` (single string, not the array form). + + Anything else is permitted. Tools that need richer validation should + do it inside :meth:`Tool.run` (or use pydantic via + :func:`json_schema_from_pydantic`). + + Raises: + SchemaValidationError: on missing required fields or type mismatches. + """ + if not schema: + return + if schema.get("type") not in (None, "object"): + return + required = schema.get("required") or () + for key in required: + if key not in args: + raise SchemaValidationError(f"missing required argument: {key!r}") + props: Mapping[str, Any] = schema.get("properties") or {} + for key, sub in props.items(): + if key not in args: + continue + expected = sub.get("type") + if not expected: + continue + py_type = _TYPE_MAP.get(expected) + if py_type is None: + continue + if not isinstance(args[key], py_type): + raise SchemaValidationError( + f"argument {key!r} expected JSON type {expected!r}, " + f"got {type(args[key]).__name__}" + ) + + +def json_schema_from_pydantic(model: Any) -> dict[str, Any]: + """Return the JSON Schema dict for a ``pydantic.BaseModel`` subclass. + + Equivalent to ``model.model_json_schema()`` (pydantic v2). 
Raises + ``ImportError`` if pydantic isn't installed. + + Examples: + >>> from pydantic import BaseModel # doctest: +SKIP + >>> class Args(BaseModel): + ... code: str + >>> json_schema_from_pydantic(Args) # doctest: +SKIP + {'type': 'object', 'properties': {'code': {'type': 'string'}}, ...} + """ + if not _has_pydantic: + raise ImportError( + "pydantic is not installed. Install pydantic or pass a JSON " + "Schema dict directly to your Tool's input_schema." + ) + if hasattr(model, "model_json_schema"): + return model.model_json_schema() + raise TypeError( + f"{model!r} is not a pydantic v2 BaseModel subclass." + ) From dea83b541edd07e916ec0701728ecb0800c25e7e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 11 May 2026 08:19:55 +0100 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- docs/source/reference/llms_envs.rst | 33 +- test/llm/test_agentic.py | 433 +++++++++++++++++ test/llm/test_llm_transforms.py | 446 ------------------ torchrl/envs/llm/agentic/__init__.py | 13 +- torchrl/envs/llm/agentic/parsers/anthropic.py | 8 +- .../envs/llm/agentic/parsers/json_block.py | 8 +- torchrl/envs/llm/agentic/parsers/openai.py | 18 +- torchrl/envs/llm/agentic/parsers/xml.py | 10 +- torchrl/envs/llm/agentic/protocols.py | 164 ++++--- torchrl/envs/llm/agentic/repl/__init__.py | 3 +- torchrl/envs/llm/agentic/repl/base.py | 81 ++-- torchrl/envs/llm/agentic/repl/jupyter.py | 8 +- torchrl/envs/llm/agentic/repl/subprocess.py | 21 +- torchrl/envs/llm/agentic/sandbox/__init__.py | 5 +- torchrl/envs/llm/agentic/sandbox/base.py | 172 ++++--- torchrl/envs/llm/agentic/sandbox/docker.py | 4 +- torchrl/envs/llm/agentic/sandbox/e2b.py | 4 +- torchrl/envs/llm/agentic/sandbox/modal.py | 4 +- .../llm/agentic/sandbox/subprocess_bwrap.py | 11 +- .../agentic/sandbox/subprocess_seatbelt.py | 10 +- torchrl/envs/llm/agentic/sandbox/unsafe.py | 10 +- torchrl/envs/llm/agentic/schema.py | 8 +- 22 files changed, 753 insertions(+), 721 deletions(-) create mode 100644 test/llm/test_agentic.py diff --git a/docs/source/reference/llms_envs.rst b/docs/source/reference/llms_envs.rst index 879ba3e6d4f..5fa460116e3 100644 --- a/docs/source/reference/llms_envs.rst +++ b/docs/source/reference/llms_envs.rst @@ -35,9 +35,32 @@ Agentic toolkit (preview) .. currentmodule:: torchrl.envs.llm.agentic The :mod:`torchrl.envs.llm.agentic` package provides a SOTA, async-first -substrate for tool-calling agents. The headline orchestrator -(``ToolCompose``) lands in a follow-up commit; this preview ships the -contracts, parsers, sandboxing, and stateful REPLs that it builds on. +substrate for tool-calling agents on top of an unmodified +:class:`~torchrl.envs.llm.ChatEnv`: structured parsers for the major +provider protocols (XML, JSON-block, OpenAI ``tool_calls``, Anthropic +``tool_use``), hardened :class:`Sandbox` backends, and stateful +:class:`Repl` sessions. + +This preview ships the substrate the headline orchestrator +(``ToolCompose``) is built on. A minimal end-to-end sketch -- usable +today against the substrate, formalised by the orchestrator -- looks +like: + +.. 
code-block:: python + + from torchrl.envs.llm.agentic.parsers import XMLToolCallParser + from torchrl.envs.llm.agentic.sandbox import default_sandbox, ResourceLimits + from torchrl.envs.llm.agentic.repl import SubprocessRepl + + parser = XMLToolCallParser() + parsed = parser.parse('{"code": "print(2+2)"}') + # -> parsed.calls[0].tool == "python", parsed.calls[0].call_id == "c1" + + sandbox = default_sandbox(ResourceLimits(wall_seconds=10, network="none")) + async def run(): + async with sandbox, SubprocessRepl(sandbox) as repl: + result = await repl.execute("print(2+2)") + assert result.stdout.strip() == "4" Tool contracts ~~~~~~~~~~~~~~ @@ -96,8 +119,8 @@ provides a no-isolation fallback that warns loudly on every Apple has officially deprecated ``sandbox-exec``, but it still ships with macOS 14+ and remains the most portable in-process isolation primitive on that platform. For stronger or cross-platform - isolation, prefer :class:`DockerSandbox` (real implementation - tracked in the package TODO list). + isolation, prefer :class:`DockerSandbox` (currently a stub -- + contributions welcome). .. autosummary:: :toctree: generated/ diff --git a/test/llm/test_agentic.py b/test/llm/test_agentic.py new file mode 100644 index 00000000000..6816ac56bd9 --- /dev/null +++ b/test/llm/test_agentic.py @@ -0,0 +1,433 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the agentic toolkit under :mod:`torchrl.envs.llm.agentic`. + +Split out from ``test/llm/test_llm_transforms.py``: the new package is +large enough -- parsers, sandbox backends, REPLs, and (in follow-up +commits) ToolCompose plus built-in tools -- that bundling its tests +into the legacy file made discovery awkward. +""" +from __future__ import annotations + +import asyncio +import json +import sys +import warnings + +import pytest +from tensordict import set_list_to_stack + +from torchrl.envs.llm.agentic import ( + ParsedCall, + Tool, + ToolCallParser, + ToolResult, + validate_args, +) +from torchrl.envs.llm.agentic.parsers import ( + AnthropicToolUseParser, + JSONToolCallParser, + OpenAIToolCallParser, + XMLToolCallParser, +) +from torchrl.envs.llm.agentic.repl import _has_jupyter_client, SubprocessRepl +from torchrl.envs.llm.agentic.sandbox import ( + BubblewrapSandbox, + default_sandbox, + ResourceLimits, + SeatbeltSandbox, + UnsafeSubprocessSandbox, +) +from torchrl.envs.llm.agentic.sandbox.subprocess_bwrap import _has_bwrap +from torchrl.envs.llm.agentic.sandbox.subprocess_seatbelt import _has_sandbox_exec + + +@pytest.fixture(scope="module", autouse=True) +def list_to_stack_fixture(): + with set_list_to_stack(True): + yield + + +def _run(coro): + return asyncio.run(coro) + + +class TestAgenticParsers: + """Per-parser conformance: parse, render_call round-trip, render_result, + stable call_id (parser-supplied or assigned). 
+ """ + + @pytest.mark.parametrize( + "parser_cls", + [ + XMLToolCallParser, + JSONToolCallParser, + OpenAIToolCallParser, + AnthropicToolUseParser, + ], + ) + def test_implements_protocol(self, parser_cls): + p = parser_cls() + assert isinstance(p, ToolCallParser) + assert isinstance(p.name, str) and p.name + + def test_xml_parse_and_call_id(self): + p = XMLToolCallParser() + r = p.parse('{"text": "hi"}tail') + assert len(r.calls) == 1 + c = r.calls[0] + assert c.tool == "echo" + assert c.args == {"text": "hi"} + assert c.call_id == "t1" # tag becomes call_id when present + assert c.tag == "t1" + assert r.text == "tail" + + def test_xml_assigns_call_id_when_no_tag(self): + p = XMLToolCallParser() + r = p.parse('{}') + assert r.calls[0].call_id # non-empty + assert r.calls[0].tag is None + + def test_xml_round_trip(self): + p = XMLToolCallParser() + call = ParsedCall(tool="echo", args={"text": "hi"}, call_id="abc", tag="abc") + rendered = p.render_call(call) + re_parsed = p.parse(rendered) + assert re_parsed.calls[0].tool == "echo" + assert re_parsed.calls[0].args == {"text": "hi"} + assert re_parsed.calls[0].call_id == "abc" + + def test_xml_render_result(self): + p = XMLToolCallParser() + msg = p.render_result("c1", ToolResult.from_text("output")) + assert msg["role"] == "tool" + assert "c1" in msg["content"] + assert "output" in msg["content"] + + def test_json_block_parse_with_id(self): + p = JSONToolCallParser() + resp = json.dumps( + { + "message": "ok", + "tools": [{"tool": "echo", "args": {"x": 1}, "id": "j1"}], + } + ) + r = p.parse(resp) + assert r.text == "ok" + assert r.calls[0].tool == "echo" + assert r.calls[0].args == {"x": 1} + assert r.calls[0].call_id == "j1" + + def test_json_block_assigns_call_id(self): + p = JSONToolCallParser() + resp = json.dumps({"message": "", "tools": [{"tool": "x", "args": {}}]}) + r = p.parse(resp) + assert r.calls[0].call_id # uuid hex + + def test_json_block_invalid_json_falls_back_to_text(self): + p = JSONToolCallParser() + r = p.parse("not json at all") + assert r.text == "not json at all" + assert r.calls == () + + def test_openai_preserves_id_and_decodes_args(self): + p = OpenAIToolCallParser() + r = p.parse( + { + "role": "assistant", + "content": "thinking", + "tool_calls": [ + { + "id": "call_a", + "type": "function", + "function": { + "name": "search", + "arguments": '{"q": "torchrl"}', + }, + } + ], + } + ) + assert r.calls[0].tool == "search" + assert r.calls[0].args == {"q": "torchrl"} + assert r.calls[0].call_id == "call_a" + + def test_openai_render_result_uses_tool_call_id(self): + p = OpenAIToolCallParser() + msg = p.render_result("call_a", ToolResult.from_text("done")) + assert msg["role"] == "tool" + assert msg["tool_call_id"] == "call_a" + assert msg["content"] == "done" + + def test_anthropic_extracts_text_and_tool_use(self): + p = AnthropicToolUseParser() + r = p.parse( + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me search."}, + { + "type": "tool_use", + "id": "toolu_a", + "name": "search", + "input": {"q": "x"}, + }, + ], + } + ) + assert r.text == "Let me search." 
+ assert r.calls[0].tool == "search" + assert r.calls[0].args == {"q": "x"} + assert r.calls[0].call_id == "toolu_a" + + def test_anthropic_render_result_uses_tool_use_id(self): + p = AnthropicToolUseParser() + msg = p.render_result("toolu_a", ToolResult.from_text("hit", is_error=False)) + assert msg["role"] == "user" + assert msg["content"][0]["type"] == "tool_result" + assert msg["content"][0]["tool_use_id"] == "toolu_a" + + def test_validate_args_required(self): + schema = { + "type": "object", + "properties": {"code": {"type": "string"}}, + "required": ["code"], + } + validate_args({"code": "print(1)"}, schema) + with pytest.raises(Exception): + validate_args({}, schema) + + def test_validate_args_type_mismatch(self): + schema = { + "type": "object", + "properties": {"n": {"type": "integer"}}, + } + validate_args({"n": 3}, schema) + with pytest.raises(Exception): + validate_args({"n": "three"}, schema) + + def test_tool_protocol_runtime_check(self): + class _T: + name = "t" + description = "d" + input_schema = {"type": "object", "properties": {}} + output_schema = None + wants_state = False + + async def run(self, args, ctx): + return ToolResult.from_text("ok") + + async def setup(self): + pass + + async def teardown(self): + pass + + assert isinstance(_T(), Tool) + + +class TestAgenticSandbox: + """Sandbox protocol conformance + sandbox-escape negatives.""" + + def test_unsafe_warns_on_open(self): + async def go(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + async with UnsafeSubprocessSandbox() as _s: + pass + assert any(issubclass(w.category, UserWarning) for w in caught) + + _run(go()) + + def test_unsafe_runs_simple_command(self): + async def go(): + async with UnsafeSubprocessSandbox(ResourceLimits(wall_seconds=5)) as s: + r = await s.run(["/bin/echo", "hello"]) + assert r.exit_code == 0 + assert r.stdout.strip() == "hello" + assert not r.timed_out + + _run(go()) + + def test_unsafe_timeout(self): + async def go(): + async with UnsafeSubprocessSandbox(ResourceLimits(wall_seconds=0.2)) as s: + r = await s.run(["/bin/sleep", "5"]) + assert r.timed_out + + _run(go()) + + def test_resource_limits_narrow(self): + a = ResourceLimits(wall_seconds=10, network="full") + b = ResourceLimits(wall_seconds=2, network="none") + c = a.narrow(b) + assert c.wall_seconds == 2 + assert c.network == "none" + # Reverse direction: narrow keeps the strictest. + c2 = b.narrow(a) + assert c2.wall_seconds == 2 + assert c2.network == "none" + + def test_default_sandbox_picks_platform(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + s = default_sandbox() + if sys.platform.startswith("linux") and _has_bwrap: + assert isinstance(s, BubblewrapSandbox) + elif sys.platform == "darwin" and _has_sandbox_exec: + assert isinstance(s, SeatbeltSandbox) + else: + assert isinstance(s, UnsafeSubprocessSandbox) + + @pytest.mark.skipif( + not (sys.platform.startswith("linux") and _has_bwrap), + reason="bubblewrap not available", + ) + def test_bubblewrap_blocks_fs_escape(self, tmp_path): + """Writes outside fs_write_roots must fail.""" + write_root = tmp_path / "work" + write_root.mkdir() + outside = tmp_path / "forbidden" + + async def go(): + limits = ResourceLimits( + wall_seconds=5, + network="none", + fs_write_roots=(str(write_root),), + ) + async with BubblewrapSandbox(limits=limits) as s: + # Inside the write root: must succeed. 
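+                # BubblewrapSandbox bind-mounts each fs_write_root read-write
+                # on top of a read-only bind of the host root, so only paths
+                # under the write root should accept writes.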
+ inside_path = write_root / "inside.txt" + r = await s.run( + [ + "/bin/sh", + "-c", + f"echo hi > {inside_path}", + ] + ) + assert r.exit_code == 0 + assert inside_path.read_text().strip() == "hi" + # Outside the write root: must fail. + r2 = await s.run( + [ + "/bin/sh", + "-c", + f"echo nope > {outside}", + ] + ) + assert r2.exit_code != 0 + assert not outside.exists() + + _run(go()) + + @pytest.mark.skipif( + not (sys.platform.startswith("linux") and _has_bwrap), + reason="bubblewrap not available", + ) + def test_bubblewrap_blocks_network(self): + """network='none' must block outbound TCP.""" + + async def go(): + limits = ResourceLimits(wall_seconds=5, network="none") + async with BubblewrapSandbox(limits=limits) as s: + r = await s.run( + [ + "python3", + "-c", + "import socket; " + "socket.create_connection(('1.1.1.1', 80), timeout=2)", + ] + ) + assert r.exit_code != 0 + + _run(go()) + + @pytest.mark.skipif( + not (sys.platform == "darwin" and _has_sandbox_exec), + reason="sandbox-exec not available", + ) + def test_seatbelt_blocks_fs_escape(self, tmp_path): + write_root = tmp_path / "work" + write_root.mkdir() + outside = tmp_path / "forbidden" + + async def go(): + limits = ResourceLimits( + wall_seconds=5, + network="none", + fs_write_roots=(str(write_root),), + ) + async with SeatbeltSandbox(limits=limits) as s: + inside_path = write_root / "inside.txt" + r = await s.run(["/bin/sh", "-c", f"echo hi > {inside_path}"]) + assert r.exit_code == 0 + r2 = await s.run(["/bin/sh", "-c", f"echo nope > {outside}"]) + assert r2.exit_code != 0 + assert not outside.exists() + + _run(go()) + + +class TestAgenticRepl: + """REPL state, error capture, restart, timeout.""" + + def test_subprocess_repl_state_persists(self): + async def go(): + async with UnsafeSubprocessSandbox(ResourceLimits(wall_seconds=10)) as s: + async with SubprocessRepl(s) as r: + r1 = await r.execute("x = 41") + assert r1.error is None + r2 = await r.execute("print(x + 1)") + assert r2.error is None + assert r2.stdout.strip() == "42" + + _run(go()) + + def test_subprocess_repl_captures_errors(self): + async def go(): + async with UnsafeSubprocessSandbox(ResourceLimits(wall_seconds=10)) as s: + async with SubprocessRepl(s) as r: + res = await r.execute("1/0") + assert res.error is not None + assert res.error.ename == "ZeroDivisionError" + + _run(go()) + + def test_subprocess_repl_restart_clears_state(self): + async def go(): + async with UnsafeSubprocessSandbox(ResourceLimits(wall_seconds=10)) as s: + async with SubprocessRepl(s) as r: + await r.execute("y = 99") + await r.restart() + res = await r.execute("print(y)") + assert res.error is not None # NameError + + _run(go()) + + def test_subprocess_repl_timeout(self): + async def go(): + async with UnsafeSubprocessSandbox(ResourceLimits(wall_seconds=10)) as s: + async with SubprocessRepl(s) as r: + res = await r.execute("import time; time.sleep(5)", timeout=0.3) + assert res.timed_out + + _run(go()) + + @pytest.mark.skipif(not _has_jupyter_client, reason="jupyter_client not installed") + @pytest.mark.slow + def test_jupyter_repl_state_persists(self): + from torchrl.envs.llm.agentic.repl import JupyterRepl + + async def go(): + async with UnsafeSubprocessSandbox(ResourceLimits(wall_seconds=60)) as s: + async with JupyterRepl(s) as r: + r1 = await r.execute("x = 41", timeout=30) + assert r1.error is None, r1 + r2 = await r.execute("print(x + 1)", timeout=30) + assert r2.error is None + assert r2.stdout.strip() == "42" + + _run(go()) diff --git 
a/test/llm/test_llm_transforms.py b/test/llm/test_llm_transforms.py index 99e52973efd..b6d0c1a6853 100644 --- a/test/llm/test_llm_transforms.py +++ b/test/llm/test_llm_transforms.py @@ -718,449 +718,3 @@ def test_empty_history_handling(self, tokenizer): assert ("tokens", "prompt") in result.keys(True, True) tokens = result.get(("tokens", "prompt"), as_list=True) assert tokens[0].numel() > 0 - - -# --------------------------------------------------------------------------- -# Agentic toolkit (torchrl.envs.llm.agentic) -# --------------------------------------------------------------------------- - -import asyncio # noqa: E402 -import socket # noqa: E402 -import sys # noqa: E402 -import warnings # noqa: E402 - -from torchrl.envs.llm.agentic import ( # noqa: E402 - ParsedCall, - TextPart, - Tool, - ToolCallParser, - ToolContext, - ToolResult, - validate_args, -) -from torchrl.envs.llm.agentic.parsers import ( # noqa: E402 - AnthropicToolUseParser, - JSONToolCallParser, - OpenAIToolCallParser, - XMLToolCallParser, -) -from torchrl.envs.llm.agentic.repl import ( # noqa: E402 - SubprocessRepl, - _has_jupyter_client, -) -from torchrl.envs.llm.agentic.sandbox import ( # noqa: E402 - BubblewrapSandbox, - ResourceLimits, - SandboxError, - SeatbeltSandbox, - UnsafeSubprocessSandbox, - default_sandbox, -) -from torchrl.envs.llm.agentic.sandbox.subprocess_bwrap import _has_bwrap # noqa: E402 -from torchrl.envs.llm.agentic.sandbox.subprocess_seatbelt import ( # noqa: E402 - _has_sandbox_exec, -) - - -def _run(coro): - return asyncio.get_event_loop().run_until_complete(coro) if False else asyncio.run(coro) - - -class TestAgenticParsers: - """Per-parser conformance: parse, render_call round-trip, render_result, - stable call_id (parser-supplied or assigned). - """ - - @pytest.mark.parametrize( - "parser_cls", - [XMLToolCallParser, JSONToolCallParser, OpenAIToolCallParser, AnthropicToolUseParser], - ) - def test_implements_protocol(self, parser_cls): - p = parser_cls() - assert isinstance(p, ToolCallParser) - assert isinstance(p.name, str) and p.name - - def test_xml_parse_and_call_id(self): - p = XMLToolCallParser() - r = p.parse('{"text": "hi"}tail') - assert len(r.calls) == 1 - c = r.calls[0] - assert c.tool == "echo" - assert c.args == {"text": "hi"} - assert c.call_id == "t1" # tag becomes call_id when present - assert c.tag == "t1" - assert r.text == "tail" - - def test_xml_assigns_call_id_when_no_tag(self): - p = XMLToolCallParser() - r = p.parse('{}') - assert r.calls[0].call_id # non-empty - assert r.calls[0].tag is None - - def test_xml_round_trip(self): - p = XMLToolCallParser() - call = ParsedCall( - tool="echo", args={"text": "hi"}, call_id="abc", tag="abc" - ) - rendered = p.render_call(call) - re_parsed = p.parse(rendered) - assert re_parsed.calls[0].tool == "echo" - assert re_parsed.calls[0].args == {"text": "hi"} - assert re_parsed.calls[0].call_id == "abc" - - def test_xml_render_result(self): - p = XMLToolCallParser() - msg = p.render_result("c1", ToolResult.from_text("output")) - assert msg["role"] == "tool" - assert "c1" in msg["content"] - assert "output" in msg["content"] - - def test_json_block_parse_with_id(self): - p = JSONToolCallParser() - resp = json.dumps( - { - "message": "ok", - "tools": [{"tool": "echo", "args": {"x": 1}, "id": "j1"}], - } - ) - r = p.parse(resp) - assert r.text == "ok" - assert r.calls[0].tool == "echo" - assert r.calls[0].args == {"x": 1} - assert r.calls[0].call_id == "j1" - - def test_json_block_assigns_call_id(self): - p = JSONToolCallParser() - resp = 
json.dumps({"message": "", "tools": [{"tool": "x", "args": {}}]}) - r = p.parse(resp) - assert r.calls[0].call_id # uuid hex - - def test_json_block_invalid_json_falls_back_to_text(self): - p = JSONToolCallParser() - r = p.parse("not json at all") - assert r.text == "not json at all" - assert r.calls == () - - def test_openai_preserves_id_and_decodes_args(self): - p = OpenAIToolCallParser() - r = p.parse( - { - "role": "assistant", - "content": "thinking", - "tool_calls": [ - { - "id": "call_a", - "type": "function", - "function": { - "name": "search", - "arguments": '{"q": "torchrl"}', - }, - } - ], - } - ) - assert r.calls[0].tool == "search" - assert r.calls[0].args == {"q": "torchrl"} - assert r.calls[0].call_id == "call_a" - - def test_openai_render_result_uses_tool_call_id(self): - p = OpenAIToolCallParser() - msg = p.render_result("call_a", ToolResult.from_text("done")) - assert msg["role"] == "tool" - assert msg["tool_call_id"] == "call_a" - assert msg["content"] == "done" - - def test_anthropic_extracts_text_and_tool_use(self): - p = AnthropicToolUseParser() - r = p.parse( - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Let me search."}, - { - "type": "tool_use", - "id": "toolu_a", - "name": "search", - "input": {"q": "x"}, - }, - ], - } - ) - assert r.text == "Let me search." - assert r.calls[0].tool == "search" - assert r.calls[0].args == {"q": "x"} - assert r.calls[0].call_id == "toolu_a" - - def test_anthropic_render_result_uses_tool_use_id(self): - p = AnthropicToolUseParser() - msg = p.render_result( - "toolu_a", ToolResult.from_text("hit", is_error=False) - ) - assert msg["role"] == "user" - assert msg["content"][0]["type"] == "tool_result" - assert msg["content"][0]["tool_use_id"] == "toolu_a" - - def test_validate_args_required(self): - schema = { - "type": "object", - "properties": {"code": {"type": "string"}}, - "required": ["code"], - } - validate_args({"code": "print(1)"}, schema) - with pytest.raises(Exception): - validate_args({}, schema) - - def test_validate_args_type_mismatch(self): - schema = { - "type": "object", - "properties": {"n": {"type": "integer"}}, - } - validate_args({"n": 3}, schema) - with pytest.raises(Exception): - validate_args({"n": "three"}, schema) - - def test_tool_protocol_runtime_check(self): - class _T: - name = "t" - description = "d" - input_schema = {"type": "object", "properties": {}} - output_schema = None - wants_state = False - - async def run(self, args, ctx): - return ToolResult.from_text("ok") - - async def setup(self): - pass - - async def teardown(self): - pass - - assert isinstance(_T(), Tool) - - -class TestAgenticSandbox: - """Sandbox protocol conformance + sandbox-escape negatives.""" - - def test_unsafe_warns_on_open(self): - async def go(): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - async with UnsafeSubprocessSandbox() as _s: - pass - assert any( - issubclass(w.category, UserWarning) for w in caught - ) - - _run(go()) - - def test_unsafe_runs_simple_command(self): - async def go(): - async with UnsafeSubprocessSandbox( - ResourceLimits(wall_seconds=5) - ) as s: - r = await s.run(["/bin/echo", "hello"]) - assert r.exit_code == 0 - assert r.stdout.strip() == "hello" - assert not r.timed_out - - _run(go()) - - def test_unsafe_timeout(self): - async def go(): - async with UnsafeSubprocessSandbox( - ResourceLimits(wall_seconds=0.2) - ) as s: - r = await s.run(["/bin/sleep", "5"]) - assert r.timed_out - - _run(go()) - - def test_resource_limits_narrow(self): - 
a = ResourceLimits(wall_seconds=10, network="full") - b = ResourceLimits(wall_seconds=2, network="none") - c = a.narrow(b) - assert c.wall_seconds == 2 - assert c.network == "none" - # Reverse direction: narrow keeps the strictest. - c2 = b.narrow(a) - assert c2.wall_seconds == 2 - assert c2.network == "none" - - def test_default_sandbox_picks_platform(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - s = default_sandbox() - if sys.platform.startswith("linux") and _has_bwrap: - assert isinstance(s, BubblewrapSandbox) - elif sys.platform == "darwin" and _has_sandbox_exec: - assert isinstance(s, SeatbeltSandbox) - else: - assert isinstance(s, UnsafeSubprocessSandbox) - - @pytest.mark.skipif( - not (sys.platform.startswith("linux") and _has_bwrap), - reason="bubblewrap not available", - ) - def test_bubblewrap_blocks_fs_escape(self, tmp_path): - """Writes outside fs_write_roots must fail.""" - write_root = tmp_path / "work" - write_root.mkdir() - outside = tmp_path / "forbidden" - - async def go(): - limits = ResourceLimits( - wall_seconds=5, - network="none", - fs_write_roots=(str(write_root),), - ) - async with BubblewrapSandbox(limits=limits) as s: - # Inside the write root: must succeed. - inside_path = write_root / "inside.txt" - r = await s.run( - [ - "/bin/sh", - "-c", - f"echo hi > {inside_path}", - ] - ) - assert r.exit_code == 0 - assert inside_path.read_text().strip() == "hi" - # Outside the write root: must fail. - r2 = await s.run( - [ - "/bin/sh", - "-c", - f"echo nope > {outside}", - ] - ) - assert r2.exit_code != 0 - assert not outside.exists() - - _run(go()) - - @pytest.mark.skipif( - not (sys.platform.startswith("linux") and _has_bwrap), - reason="bubblewrap not available", - ) - def test_bubblewrap_blocks_network(self): - """network='none' must block outbound TCP.""" - async def go(): - limits = ResourceLimits(wall_seconds=5, network="none") - async with BubblewrapSandbox(limits=limits) as s: - r = await s.run( - [ - "python3", - "-c", - "import socket; " - "socket.create_connection(('1.1.1.1', 80), timeout=2)", - ] - ) - assert r.exit_code != 0 - - _run(go()) - - @pytest.mark.skipif( - not (sys.platform == "darwin" and _has_sandbox_exec), - reason="sandbox-exec not available", - ) - def test_seatbelt_blocks_fs_escape(self, tmp_path): - write_root = tmp_path / "work" - write_root.mkdir() - outside = tmp_path / "forbidden" - - async def go(): - limits = ResourceLimits( - wall_seconds=5, - network="none", - fs_write_roots=(str(write_root),), - ) - async with SeatbeltSandbox(limits=limits) as s: - inside_path = write_root / "inside.txt" - r = await s.run( - ["/bin/sh", "-c", f"echo hi > {inside_path}"] - ) - assert r.exit_code == 0 - r2 = await s.run( - ["/bin/sh", "-c", f"echo nope > {outside}"] - ) - assert r2.exit_code != 0 - assert not outside.exists() - - _run(go()) - - -class TestAgenticRepl: - """REPL state, error capture, restart, timeout.""" - - def test_subprocess_repl_state_persists(self): - async def go(): - async with UnsafeSubprocessSandbox( - ResourceLimits(wall_seconds=10) - ) as s: - async with SubprocessRepl(s) as r: - r1 = await r.execute("x = 41") - assert r1.error is None - r2 = await r.execute("print(x + 1)") - assert r2.error is None - assert r2.stdout.strip() == "42" - - _run(go()) - - def test_subprocess_repl_captures_errors(self): - async def go(): - async with UnsafeSubprocessSandbox( - ResourceLimits(wall_seconds=10) - ) as s: - async with SubprocessRepl(s) as r: - res = await r.execute("1/0") - assert res.error is not 
None - assert res.error.ename == "ZeroDivisionError" - - _run(go()) - - def test_subprocess_repl_restart_clears_state(self): - async def go(): - async with UnsafeSubprocessSandbox( - ResourceLimits(wall_seconds=10) - ) as s: - async with SubprocessRepl(s) as r: - await r.execute("y = 99") - await r.restart() - res = await r.execute("print(y)") - assert res.error is not None # NameError - - _run(go()) - - def test_subprocess_repl_timeout(self): - async def go(): - async with UnsafeSubprocessSandbox( - ResourceLimits(wall_seconds=10) - ) as s: - async with SubprocessRepl(s) as r: - res = await r.execute( - "import time; time.sleep(5)", timeout=0.3 - ) - assert res.timed_out - - _run(go()) - - @pytest.mark.skipif( - not _has_jupyter_client, reason="jupyter_client not installed" - ) - @pytest.mark.slow - def test_jupyter_repl_state_persists(self): - from torchrl.envs.llm.agentic.repl import JupyterRepl - - async def go(): - async with UnsafeSubprocessSandbox( - ResourceLimits(wall_seconds=60) - ) as s: - async with JupyterRepl(s) as r: - r1 = await r.execute("x = 41", timeout=30) - assert r1.error is None, r1 - r2 = await r.execute("print(x + 1)", timeout=30) - assert r2.error is None - assert r2.stdout.strip() == "42" - - _run(go()) diff --git a/torchrl/envs/llm/agentic/__init__.py b/torchrl/envs/llm/agentic/__init__.py index 84406ff8233..7bdea7108dc 100644 --- a/torchrl/envs/llm/agentic/__init__.py +++ b/torchrl/envs/llm/agentic/__init__.py @@ -14,17 +14,6 @@ See ``docs/source/reference/llms_envs.rst`` and ``docs/source/tutorials/llm_agentic.rst`` for a walkthrough. """ -# TODO: contributors please update as items are picked up. -# - streaming tool results (AsyncIterator[ToolEvent] from Tool.run) -# - per-tool token-budget accounting -# - E2B / Modal real implementations (stubs land first) -# - harmony parser (gpt-oss / o1-style) -# - Ray dispatcher (ToolCompose(parallel="ray")) -# - multimodal tool outputs (image / audio in ToolResult.parts) -# - structured-output validation against Tool.output_schema -# - per-tool retry / circuit breaker -# - tool-result caching (content-addressed) for replay -# - formal deprecation of legacy tool transforms once the new API soaks from __future__ import annotations from .protocols import ( @@ -33,13 +22,13 @@ JsonPart, ParsedCall, ParseResult, + TextPart, Tool, ToolCallParser, ToolContext, ToolError, ToolResult, ToolResultPart, - TextPart, ) from .schema import json_schema_from_pydantic, validate_args diff --git a/torchrl/envs/llm/agentic/parsers/anthropic.py b/torchrl/envs/llm/agentic/parsers/anthropic.py index d2036e36071..aca2c1ad272 100644 --- a/torchrl/envs/llm/agentic/parsers/anthropic.py +++ b/torchrl/envs/llm/agentic/parsers/anthropic.py @@ -21,7 +21,9 @@ class AnthropicToolUseParser: """Parses Anthropic-style ``tool_use`` content blocks. - Accepts either the full assistant message:: + Accepts either the full assistant message: + + .. 
code-block:: json { "role": "assistant", @@ -99,9 +101,7 @@ def render_call(self, call: ParsedCall) -> str: ensure_ascii=False, ) - def render_result( - self, call_id: str, result: ToolResult - ) -> Mapping[str, Any]: + def render_result(self, call_id: str, result: ToolResult) -> Mapping[str, Any]: return { "role": "user", "content": [ diff --git a/torchrl/envs/llm/agentic/parsers/json_block.py b/torchrl/envs/llm/agentic/parsers/json_block.py index c290b72dbfb..de040907854 100644 --- a/torchrl/envs/llm/agentic/parsers/json_block.py +++ b/torchrl/envs/llm/agentic/parsers/json_block.py @@ -16,7 +16,9 @@ class JSONToolCallParser: """Parses LLM responses formatted as a single JSON object. - Expected shape:: + The expected shape is: + + .. code-block:: json { "message": "Let me search.", @@ -72,9 +74,7 @@ def render_call(self, call: ParsedCall) -> str: ensure_ascii=False, ) - def render_result( - self, call_id: str, result: ToolResult - ) -> Mapping[str, Any]: + def render_result(self, call_id: str, result: ToolResult) -> Mapping[str, Any]: return { "role": "tool", "content": json.dumps( diff --git a/torchrl/envs/llm/agentic/parsers/openai.py b/torchrl/envs/llm/agentic/parsers/openai.py index a785903748c..d5367e9b1ca 100644 --- a/torchrl/envs/llm/agentic/parsers/openai.py +++ b/torchrl/envs/llm/agentic/parsers/openai.py @@ -22,17 +22,19 @@ class OpenAIToolCallParser: """Parses OpenAI-style ``tool_calls`` from an assistant message. - Accepts any of these shapes: + Accepts any of these shapes -- the full message dict: - - The full message dict:: + .. code-block:: json - {"role": "assistant", "content": "...", "tool_calls": [...]} + {"role": "assistant", "content": "...", "tool_calls": [...]} - - The choice dict:: + the choice dict: - {"message": {... "tool_calls": [...]}} + .. code-block:: json - - A bare list under ``tool_calls`` at the top level. + {"message": {"role": "assistant", "tool_calls": [...]}} + + or a bare list under ``tool_calls`` at the top level. Each call's ``id`` is preserved as :attr:`ParsedCall.call_id`. Arguments are JSON-decoded from the ``function.arguments`` string. @@ -111,9 +113,7 @@ def render_call(self, call: ParsedCall) -> str: ensure_ascii=False, ) - def render_result( - self, call_id: str, result: ToolResult - ) -> Mapping[str, Any]: + def render_result(self, call_id: str, result: ToolResult) -> Mapping[str, Any]: # OpenAI shape: a "tool" role message with tool_call_id correlation. 
return { "role": "tool", diff --git a/torchrl/envs/llm/agentic/parsers/xml.py b/torchrl/envs/llm/agentic/parsers/xml.py index e305b0f006b..16085de43fe 100644 --- a/torchrl/envs/llm/agentic/parsers/xml.py +++ b/torchrl/envs/llm/agentic/parsers/xml.py @@ -47,11 +47,7 @@ class XMLToolCallParser: ) def parse(self, response: str | Mapping[str, Any]) -> ParseResult: - text = ( - response - if isinstance(response, str) - else str(response.get("text", "")) - ) + text = response if isinstance(response, str) else str(response.get("text", "")) calls: list[ParsedCall] = [] def repl(m: re.Match) -> str: @@ -79,9 +75,7 @@ def render_call(self, call: ParsedCall) -> str: body = json.dumps(dict(call.args), ensure_ascii=False) return f'{body}' - def render_result( - self, call_id: str, result: ToolResult - ) -> Mapping[str, Any]: + def render_result(self, call_id: str, result: ToolResult) -> Mapping[str, Any]: body = result.text prefix = "[error] " if result.is_error else "" return { diff --git a/torchrl/envs/llm/agentic/protocols.py b/torchrl/envs/llm/agentic/protocols.py index 308e8d4a4a1..d720d24fb2d 100644 --- a/torchrl/envs/llm/agentic/protocols.py +++ b/torchrl/envs/llm/agentic/protocols.py @@ -18,77 +18,86 @@ Anthropic ``tool_use_id`` -- else a parser-assigned uuid4). Round-trips through :meth:`ToolCallParser.render_result` so downstream consumers can correlate calls and results. + +Value types (:class:`TextPart`, :class:`JsonPart`, :class:`ImagePart`, +:class:`FileRefPart`, :class:`ToolResult`, :class:`ParsedCall`, +:class:`ParseResult`, :class:`ToolContext`) are all +:class:`tensordict.TensorClass` subclasses so they stack across batch +dims and compose with TorchRL's batched envs and trajectories. """ from __future__ import annotations from collections.abc import Mapping -from dataclasses import dataclass, field -from typing import Any, ClassVar, Literal, Protocol, runtime_checkable +from typing import Any, ClassVar, Protocol, runtime_checkable -from tensordict import TensorDictBase +from tensordict import TensorClass, TensorDictBase # ----- result parts ----- -@dataclass(frozen=True, slots=True) -class TextPart: + +class TextPart(TensorClass["nocast"]): """A text fragment of a :class:`ToolResult`.""" text: str - kind: Literal["text"] = "text" + kind: str = "text" -@dataclass(frozen=True, slots=True) -class JsonPart: +class JsonPart(TensorClass["nocast"]): """A JSON-serialisable structured fragment of a :class:`ToolResult`.""" data: Any - kind: Literal["json"] = "json" + kind: str = "json" -@dataclass(frozen=True, slots=True) -class ImagePart: +class ImagePart(TensorClass["nocast"]): """An image fragment of a :class:`ToolResult` (raw bytes + media type).""" data: bytes media_type: str = "image/png" - kind: Literal["image"] = "image" + kind: str = "image" -@dataclass(frozen=True, slots=True) -class FileRefPart: +class FileRefPart(TensorClass["nocast"]): """A reference to a file produced by a tool (path inside the sandbox).""" path: str media_type: str | None = None - kind: Literal["file_ref"] = "file_ref" + kind: str = "file_ref" ToolResultPart = TextPart | JsonPart | ImagePart | FileRefPart -@dataclass(frozen=True, slots=True) -class ToolResult: +class ToolResult(TensorClass["nocast"]): """The output of a single :meth:`Tool.run` invocation. Attributes: - parts: Ordered tuple of result fragments. ``parts[0]`` is conventionally - text. Most call sites only need ``result.text``. - is_error: Whether the tool raised or otherwise produced an error. - ``parts[0]`` should describe the error when ``True``. 
- meta: Free-form metadata (timing, tokens used, raw provider payload). + parts: Ordered tuple of result fragments. ``parts[0]`` is + conventionally text. Most call sites only need + :attr:`text`. + is_error: Whether the tool raised or otherwise produced an + error. ``parts[0]`` should describe the error when ``True``. + meta: Free-form metadata (timing, tokens used, raw provider + payload). + + Stacks with :func:`tensordict.lazy_stack` so a batched env can + return one ``ToolResult`` per item without manual padding. """ - parts: tuple[ToolResultPart, ...] = () + parts: tuple = () is_error: bool = False - meta: Mapping[str, Any] = field(default_factory=dict) + meta: Mapping[str, Any] | None = None @property def text(self) -> str: - """Concatenation of all :class:`TextPart` and stringified - :class:`JsonPart` content. Convenience for the common case.""" + """Return the textual flattening of :attr:`parts`. + + Concatenates all :class:`TextPart` text and stringified + :class:`JsonPart` content. Convenience for the common case. + """ out: list[str] = [] - for p in self.parts: + for p in self.parts or (): if isinstance(p, TextPart): out.append(p.text) elif isinstance(p, JsonPart): @@ -113,38 +122,37 @@ def from_text( return cls( parts=(TextPart(text=text),), is_error=is_error, - meta=dict(meta or {}), + meta=dict(meta) if meta else None, ) -@dataclass class ToolError(Exception): """Raised by tools to signal a structured failure. Catching this in :class:`ToolCompose` produces a - :class:`ToolResult` with ``is_error=True``. Anything else surfaces as - an unstructured error (still wrapped, but flagged in ``meta``). + :class:`ToolResult` with ``is_error=True``. Anything else surfaces + as an unstructured error (still wrapped, but flagged in ``meta``). """ - message: str - detail: Mapping[str, Any] = field(default_factory=dict) - - def __str__(self) -> str: - return self.message + def __init__(self, message: str, detail: Mapping[str, Any] | None = None) -> None: + super().__init__(message) + self.message = message + self.detail = dict(detail) if detail else {} # ----- call / parse types ----- -@dataclass(frozen=True, slots=True) -class ParsedCall: + +class ParsedCall(TensorClass["nocast"]): """A single tool invocation parsed out of an LLM response. Attributes: tool: The name of the tool to invoke. args: Already-decoded keyword arguments. Validation against :attr:`Tool.input_schema` happens in :class:`ToolCompose`. - call_id: Stable identifier (parser-assigned if not present in the - source). Round-trips through :meth:`ToolCallParser.render_result`. + call_id: Stable identifier (parser-assigned if not present in + the source). Round-trips through + :meth:`ToolCallParser.render_result`. tag: Optional human-visible label (back-compat with ``ExecuteToolsInOrder``). """ @@ -155,38 +163,39 @@ class ParsedCall: tag: str | None = None -@dataclass(frozen=True, slots=True) -class ParseResult: +class ParseResult(TensorClass["nocast"]): """Output of :meth:`ToolCallParser.parse`. Attributes: - text: Cleaned message body with tool-call syntax stripped (when the - family embeds calls in the text -- XML, JSON-block). Empty for - providers where calls live in a structured field (OpenAI, - Anthropic). + text: Cleaned message body with tool-call syntax stripped (when + the family embeds calls in the text -- XML, JSON-block). + Empty for providers where calls live in a structured field + (OpenAI, Anthropic). calls: Calls in the order the model emitted them. raw: The original response, for round-trip and debugging. 
""" text: str - calls: tuple[ParsedCall, ...] + calls: tuple = () raw: Any = None # ----- context passed to a Tool ----- -@dataclass -class ToolContext: + +class ToolContext(TensorClass["nocast"]): """Per-call context handed to :meth:`Tool.run`. Attributes: - call_id: The :attr:`ParsedCall.call_id`. Stable across this turn. + call_id: The :attr:`ParsedCall.call_id`. Stable across this + turn. tag: Optional :attr:`ParsedCall.tag`. - state: Read-only filtered view of the env state. Only populated when - the owning :class:`ToolCompose` has ``pass_state_to_tools=True`` - *and* the tool has ``wants_state=True``. - sandbox: The compose-level sandbox, if any. Tools may also hold - their own sandbox by reference. + state: Read-only filtered view of the env state. Only + populated when the owning :class:`ToolCompose` has + ``pass_state_to_tools=True`` *and* the tool has + ``wants_state=True``. + sandbox: The compose-level sandbox, if any. Tools may also + hold their own sandbox by reference. repl: The compose-level REPL, if any. compose: Back-reference to the owning :class:`ToolCompose` for tool-to-tool dispatch from inside a tool body. @@ -202,17 +211,18 @@ class ToolContext: # ----- protocols ----- + @runtime_checkable class Tool(Protocol): """A unit invoked by name from an LLM response. - Subclasses (or duck-typed equivalents) declare ``name``, ``description``, - and ``input_schema`` (JSON Schema dict) at the class level, and implement - an async :meth:`run`. + Subclasses (or duck-typed equivalents) declare ``name``, + ``description``, and ``input_schema`` (JSON Schema dict) at the + class level, and implement an async :meth:`run`. - A tool may opt in to receiving env state via the ``wants_state`` class - attribute -- :class:`ToolCompose` will populate ``ctx.state`` when both - sides agree. + A tool may opt in to receiving env state via the ``wants_state`` + class attribute -- :class:`ToolCompose` will populate + ``ctx.state`` when both sides agree. Example: >>> from torchrl.envs.llm.agentic import Tool, ToolContext, ToolResult @@ -236,24 +246,29 @@ class Tool(Protocol): output_schema: ClassVar[Mapping[str, Any] | None] wants_state: ClassVar[bool] - async def run( - self, args: Mapping[str, Any], ctx: ToolContext - ) -> ToolResult: ... + async def run(self, args: Mapping[str, Any], ctx: ToolContext) -> ToolResult: + ... - async def setup(self) -> None: ... + async def setup(self) -> None: + ... - async def teardown(self) -> None: ... + async def teardown(self) -> None: + ... @runtime_checkable class ToolCallParser(Protocol): - """Parses an LLM response into :class:`ParsedCall` items and renders - results back into the family's message shape. + """Parses LLM responses and renders results in a provider format. + + A :class:`ToolCallParser` extracts :class:`ParsedCall` items from + an assistant message and renders results back into the family's + message shape (OpenAI, Anthropic, XML, JSON-block). Implementations must guarantee: 1. :meth:`parse` is pure and synchronous. - 2. Every returned :class:`ParsedCall` has a non-empty :attr:`call_id`. + 2. Every returned :class:`ParsedCall` has a non-empty + :attr:`call_id`. 3. ``parse -> render_call`` round-trips for calls produced by :meth:`parse` (within the same parser family). 4. :meth:`render_result` produces a mapping suitable for one new @@ -263,10 +278,11 @@ class ToolCallParser(Protocol): name: ClassVar[str] - def parse(self, response: str | Mapping[str, Any]) -> ParseResult: ... 
+ def parse(self, response: str | Mapping[str, Any]) -> ParseResult: + ... - def render_call(self, call: ParsedCall) -> str: ... + def render_call(self, call: ParsedCall) -> str: + ... - def render_result( - self, call_id: str, result: ToolResult - ) -> Mapping[str, Any]: ... + def render_result(self, call_id: str, result: ToolResult) -> Mapping[str, Any]: + ... diff --git a/torchrl/envs/llm/agentic/repl/__init__.py b/torchrl/envs/llm/agentic/repl/__init__.py index e196695940d..371c76d237d 100644 --- a/torchrl/envs/llm/agentic/repl/__init__.py +++ b/torchrl/envs/llm/agentic/repl/__init__.py @@ -12,7 +12,7 @@ from __future__ import annotations from .base import Repl, ReplDisplay, ReplError, ReplResult -from .jupyter import JupyterRepl, _has_jupyter_client +from .jupyter import _has_jupyter_client, JupyterRepl from .subprocess import SubprocessRepl __all__ = [ @@ -22,4 +22,5 @@ "ReplError", "ReplResult", "SubprocessRepl", + "_has_jupyter_client", ] diff --git a/torchrl/envs/llm/agentic/repl/base.py b/torchrl/envs/llm/agentic/repl/base.py index 7071565013a..e1e44acb866 100644 --- a/torchrl/envs/llm/agentic/repl/base.py +++ b/torchrl/envs/llm/agentic/repl/base.py @@ -4,41 +4,47 @@ # LICENSE file in the root directory of this source tree. """REPL protocol and value types. -A :class:`Repl` runs stateful code inside a :class:`Sandbox`. State persists -across :meth:`execute` calls until :meth:`restart`. :meth:`interrupt` -preserves state but cancels the current execution. Timeouts surface as -``ReplResult.timed_out=True`` rather than raising. +A :class:`Repl` runs stateful code inside a :class:`Sandbox`. State +persists across :meth:`execute` calls until :meth:`restart`. +:meth:`interrupt` preserves state but cancels the current execution. +Timeouts surface as ``ReplResult.timed_out=True`` rather than raising. + +Value types (:class:`ReplDisplay`, :class:`ReplError`, +:class:`ReplResult`) are :class:`tensordict.TensorClass` subclasses so +they stack across batch dims and compose with batched envs. """ from __future__ import annotations -from collections.abc import Mapping -from dataclasses import dataclass, field -from typing import Any, ClassVar, Literal, Protocol, runtime_checkable +from typing import Any, ClassVar, Protocol, runtime_checkable + +from tensordict import TensorClass from ..sandbox.base import Sandbox -@dataclass(frozen=True, slots=True) -class ReplDisplay: - """A rich output (image, JSON, HTML) emitted via Jupyter's display - protocol. Subprocess REPLs emit nothing here. +class ReplDisplay(TensorClass["nocast"]): + """A rich output emitted via Jupyter's display protocol. + + Carries an image, JSON, or HTML payload. Subprocess REPLs emit + nothing here. """ media_type: str data: Any -@dataclass(frozen=True, slots=True) -class ReplError: - """Structured error from the kernel (exception name, value, traceback).""" +class ReplError(TensorClass["nocast"]): + """Structured error from the kernel. + + Captures the exception name, value, and traceback. + """ ename: str evalue: str traceback: str = "" -@dataclass(frozen=True, slots=True) -class ReplResult: +class ReplResult(TensorClass["nocast"]): """Outcome of one :meth:`Repl.execute` invocation. Attributes: @@ -47,12 +53,13 @@ class ReplResult: display: Rich outputs in emit order. error: Structured error, if any. timed_out: ``True`` if execution hit the timeout. - execution_count: Monotonic counter (Jupyter); ``-1`` for subprocess. + execution_count: Monotonic counter (Jupyter); ``-1`` for + subprocess. 
""" stdout: str = "" stderr: str = "" - display: tuple[ReplDisplay, ...] = () + display: tuple = () error: ReplError | None = None timed_out: bool = False execution_count: int = -1 @@ -65,7 +72,7 @@ def text(self) -> str: out.append(self.stdout) if self.stderr: out.append(self.stderr) - if self.error: + if self.error is not None: out.append(f"{self.error.ename}: {self.error.evalue}") return "\n".join(out).strip() @@ -74,35 +81,43 @@ def text(self) -> str: class Repl(Protocol): """Stateful code-execution session. - Lifecycle: ``open()`` is idempotent and required before ``execute()``; - ``close()`` releases the kernel. Use as ``async with repl:`` to bracket. + Lifecycle: ``open()`` is idempotent and required before + ``execute()``; ``close()`` releases the kernel. Use as + ``async with repl:`` to bracket. Invariants: - - :meth:`execute` is stateful (variables persist) until :meth:`restart`. + - :meth:`execute` is stateful (variables persist) until + :meth:`restart`. - :meth:`interrupt` does not lose state. - - :meth:`execute` never raises on user-code errors; errors surface in - :attr:`ReplResult.error`. Infrastructure failures raise. + - :meth:`execute` never raises on user-code errors; errors + surface in :attr:`ReplResult.error`. Infrastructure failures + raise. """ name: ClassVar[str] sandbox: Sandbox - async def open(self) -> None: ... + async def open(self) -> None: + ... - async def close(self) -> None: ... + async def close(self) -> None: + ... - async def __aenter__(self) -> Repl: ... + async def __aenter__(self) -> Repl: + ... - async def __aexit__(self, exc_type, exc, tb) -> None: ... + async def __aexit__(self, exc_type, exc, tb) -> None: + ... - async def execute( - self, code: str, *, timeout: float | None = None - ) -> ReplResult: ... + async def execute(self, code: str, *, timeout: float | None = None) -> ReplResult: + ... - async def interrupt(self) -> None: ... + async def interrupt(self) -> None: + ... - async def restart(self) -> None: ... + async def restart(self) -> None: + ... 
__all__ = ["Repl", "ReplDisplay", "ReplError", "ReplResult"] diff --git a/torchrl/envs/llm/agentic/repl/jupyter.py b/torchrl/envs/llm/agentic/repl/jupyter.py index 17e955ae318..595e1eaff8c 100644 --- a/torchrl/envs/llm/agentic/repl/jupyter.py +++ b/torchrl/envs/llm/agentic/repl/jupyter.py @@ -22,7 +22,7 @@ from torchrl._utils import logger as torchrl_logger from ..sandbox.base import Sandbox, SandboxError -from .base import Repl, ReplDisplay, ReplError, ReplResult +from .base import ReplDisplay, ReplError, ReplResult _has_jupyter_client = importlib.util.find_spec("jupyter_client") is not None @@ -119,9 +119,7 @@ async def __aenter__(self) -> JupyterRepl: async def __aexit__(self, exc_type, exc, tb) -> None: await self.close() - async def execute( - self, code: str, *, timeout: float | None = None - ) -> ReplResult: + async def execute(self, code: str, *, timeout: float | None = None) -> ReplResult: if self._kc is None: raise SandboxError("REPL is not open; call open() first") msg_id: str = self._kc.execute(code) @@ -164,7 +162,7 @@ async def execute( elif mtype == "status": if content.get("execution_state") == "idle": break - except asyncio.TimeoutError: + except TimeoutError: try: if self._km is not None: self._km.interrupt_kernel() diff --git a/torchrl/envs/llm/agentic/repl/subprocess.py b/torchrl/envs/llm/agentic/repl/subprocess.py index 3dbda7a7395..3f7f5608344 100644 --- a/torchrl/envs/llm/agentic/repl/subprocess.py +++ b/torchrl/envs/llm/agentic/repl/subprocess.py @@ -23,7 +23,7 @@ from typing import ClassVar from ..sandbox.base import Sandbox, SandboxError -from .base import Repl, ReplError, ReplResult +from .base import ReplError, ReplResult _BOOT = textwrap.dedent( @@ -120,7 +120,7 @@ async def close(self) -> None: self._proc.kill() try: await asyncio.wait_for(self._proc.wait(), timeout=2.0) - except asyncio.TimeoutError: # pragma: no cover + except TimeoutError: # pragma: no cover pass finally: self._proc = None @@ -132,9 +132,7 @@ async def __aenter__(self) -> SubprocessRepl: async def __aexit__(self, exc_type, exc, tb) -> None: await self.close() - async def execute( - self, code: str, *, timeout: float | None = None - ) -> ReplResult: + async def execute(self, code: str, *, timeout: float | None = None) -> ReplResult: if self._proc is None or self._proc.returncode is not None: raise SandboxError("REPL is not running; call open() first") async with self._lock: @@ -153,7 +151,6 @@ async def execute( stdout, stderr = await asyncio.wait_for( self._read_until_sentinels(sentinel), timeout=timeout ) - timed_out = False err: ReplError | None = None if stderr.endswith(f"{sentinel}_ERR\n"): body = stderr[: -len(f"{sentinel}_ERR\n")] @@ -172,7 +169,7 @@ async def execute( timed_out=False, execution_count=-1, ) - except asyncio.TimeoutError: + except TimeoutError: # Send SIGINT and let the boot loop recover. State is # preserved unless the user code is in an uninterruptible # syscall, in which case the user must call restart(). @@ -196,9 +193,7 @@ async def restart(self) -> None: await self.close() await self.open() - async def _read_until_sentinels( - self, sentinel: str - ) -> tuple[str, str]: + async def _read_until_sentinels(self, sentinel: str) -> tuple[str, str]: # Read stdout until "_END\n" appears, then drain stderr # until "_OK\n" or "_ERR\n" appears. 
assert self._proc is not None @@ -242,8 +237,10 @@ def _parse_traceback(tb: str) -> ReplError: async def _wrap_argv_via_sandbox(sandbox: Sandbox, argv: tuple[str, ...]) -> list[str]: - """Best-effort: ask the sandbox to compute the prefixed argv, fallback - to the raw argv if the backend doesn't support pre-wrapping. + """Ask the sandbox to compute the prefixed argv if it supports it. + + Falls back to the raw argv if the backend doesn't expose a + ``_build_argv`` hook. Best-effort. """ builder = getattr(sandbox, "_build_argv", None) if callable(builder): diff --git a/torchrl/envs/llm/agentic/sandbox/__init__.py b/torchrl/envs/llm/agentic/sandbox/__init__.py index ffef10b1fa2..5245eac1360 100644 --- a/torchrl/envs/llm/agentic/sandbox/__init__.py +++ b/torchrl/envs/llm/agentic/sandbox/__init__.py @@ -10,7 +10,6 @@ """ from __future__ import annotations -import shutil import sys import warnings @@ -18,8 +17,8 @@ from .docker import DockerSandbox from .e2b import E2BSandbox from .modal import ModalSandbox -from .subprocess_bwrap import BubblewrapSandbox, _has_bwrap -from .subprocess_seatbelt import SeatbeltSandbox, _has_sandbox_exec +from .subprocess_bwrap import _has_bwrap, BubblewrapSandbox +from .subprocess_seatbelt import _has_sandbox_exec, SeatbeltSandbox from .unsafe import UnsafeSubprocessSandbox diff --git a/torchrl/envs/llm/agentic/sandbox/base.py b/torchrl/envs/llm/agentic/sandbox/base.py index c35e3d0aed9..7d222544651 100644 --- a/torchrl/envs/llm/agentic/sandbox/base.py +++ b/torchrl/envs/llm/agentic/sandbox/base.py @@ -4,72 +4,82 @@ # LICENSE file in the root directory of this source tree. """Sandbox protocol and value types. -A :class:`Sandbox` is an async context manager owning an isolated execution -environment. :meth:`Sandbox.run` launches a subprocess inside it, -:meth:`Sandbox.write_file` and :meth:`Sandbox.read_file` mediate I/O. The -default backends -- :class:`BubblewrapSandbox` (Linux) and -:class:`SeatbeltSandbox` (macOS) -- enforce filesystem and network isolation -via OS-bundled tools. +A :class:`Sandbox` is an async context manager owning an isolated +execution environment. :meth:`Sandbox.run` launches a subprocess inside +it, :meth:`Sandbox.write_file` and :meth:`Sandbox.read_file` mediate +I/O. The default backends -- :class:`BubblewrapSandbox` (Linux) and +:class:`SeatbeltSandbox` (macOS) -- enforce filesystem and network +isolation via OS-bundled tools. For environments where neither is available, -:class:`UnsafeSubprocessSandbox` provides a no-op fallback that runs a bare -subprocess with no isolation. It emits a ``UserWarning`` on every +:class:`UnsafeSubprocessSandbox` provides a no-op fallback that runs a +bare subprocess with no isolation. It emits a ``UserWarning`` on every :meth:`open` call so the lack of containment is impossible to miss. + +Value types (:class:`ResourceLimits`, :class:`SandboxResult`) are +:class:`tensordict.TensorClass` subclasses so they stack across batch +dims and compose with TorchRL's batched envs. """ from __future__ import annotations from collections.abc import Mapping, Sequence -from dataclasses import dataclass, field -from typing import Any, ClassVar, Literal, Protocol, runtime_checkable +from typing import ClassVar, Literal, Protocol, runtime_checkable + +from tensordict import TensorClass _NetworkPolicy = Literal["none", "loopback", "allowlist", "full"] class SandboxError(RuntimeError): - """Raised on sandbox infrastructure failures (launch, kernel error, etc.). + """Raised on sandbox infrastructure failures. 
- Tool processes that exit non-zero do *not* raise; the non-zero status is - surfaced via :attr:`SandboxResult.exit_code`. + Covers launch failures, kernel errors, etc. Tool processes that + exit non-zero do *not* raise; the non-zero status is surfaced via + :attr:`SandboxResult.exit_code`. """ -@dataclass(frozen=True, slots=True) -class ResourceLimits: +class ResourceLimits(TensorClass["nocast"]): """Per-sandbox or per-call resource limits. Attributes: cpu_seconds: Soft CPU budget. ``None`` means unlimited. wall_seconds: Wall-clock timeout. ``None`` means unlimited. memory_bytes: Address-space cap. ``None`` means unlimited. - network: Policy for outbound network. ``"none"`` blocks all sockets, - ``"loopback"`` allows 127.0.0.0/8 only, ``"allowlist"`` consults - :attr:`network_allowlist`, ``"full"`` is unrestricted. + network: Policy for outbound network. ``"none"`` blocks all + sockets, ``"loopback"`` allows 127.0.0.0/8 only, + ``"allowlist"`` consults :attr:`network_allowlist`, + ``"full"`` is unrestricted. network_allowlist: ``host:port`` strings, used only when ``network == "allowlist"``. - fs_read_roots: Absolute paths the sandbox may read from. Empty means - backend default (typically ``/`` read-only on Linux/macOS). - fs_write_roots: Absolute paths the sandbox may write to. Empty means - no writes allowed. - max_processes: Cap on concurrent subprocesses. ``None`` for unlimited. - env: Environment-variable allowlist. ``None`` means a clean env with - only ``PATH``, ``HOME``, ``LANG``. + fs_read_roots: Absolute paths the sandbox may read from. + Empty means backend default (typically ``/`` read-only on + Linux/macOS). + fs_write_roots: Absolute paths the sandbox may write to. + Empty means no writes allowed. + max_processes: Cap on concurrent subprocesses. ``None`` for + unlimited. + env: Environment-variable allowlist. ``None`` means a clean + env with only ``PATH``, ``HOME``, ``LANG``. """ cpu_seconds: float | None = 30.0 wall_seconds: float | None = 60.0 memory_bytes: int | None = 512 * 1024 * 1024 - network: _NetworkPolicy = "none" - network_allowlist: tuple[str, ...] = () - fs_read_roots: tuple[str, ...] = () - fs_write_roots: tuple[str, ...] = () + network: str = "none" + network_allowlist: tuple = () + fs_read_roots: tuple = () + fs_write_roots: tuple = () max_processes: int | None = 32 env: Mapping[str, str] | None = None def narrow(self, other: ResourceLimits | None) -> ResourceLimits: - """Return a new :class:`ResourceLimits` that is at most as permissive - as both ``self`` and ``other``. Used by :meth:`Sandbox.run` to apply a - per-call override that may only narrow the construction limits. + """Return the tightest combination of ``self`` and ``other``. + + The result is at most as permissive as either input. Used by + :meth:`Sandbox.run` to apply a per-call override that may only + narrow the construction limits. """ if other is None: return self @@ -81,41 +91,47 @@ def _min_or(a: float | None, b: float | None) -> float | None: return a return min(a, b) - # Tighten network: choose the strictest. 
- rank: dict[_NetworkPolicy, int] = { + rank: dict[str, int] = { "none": 0, "loopback": 1, "allowlist": 2, "full": 3, } - net = self.network if rank[self.network] <= rank[other.network] else other.network + net = ( + self.network if rank[self.network] <= rank[other.network] else other.network + ) + if self.network_allowlist or other.network_allowlist: + allow = tuple( + sorted(set(self.network_allowlist) & set(other.network_allowlist)) + ) + else: + allow = () + if self.fs_read_roots and other.fs_read_roots: + read_roots = tuple( + sorted(set(self.fs_read_roots) & set(other.fs_read_roots)) + ) + else: + read_roots = self.fs_read_roots or other.fs_read_roots + if self.fs_write_roots and other.fs_write_roots: + write_roots = tuple( + sorted(set(self.fs_write_roots) & set(other.fs_write_roots)) + ) + else: + write_roots = self.fs_write_roots or other.fs_write_roots return ResourceLimits( cpu_seconds=_min_or(self.cpu_seconds, other.cpu_seconds), wall_seconds=_min_or(self.wall_seconds, other.wall_seconds), memory_bytes=_min_or(self.memory_bytes, other.memory_bytes), network=net, - network_allowlist=( - tuple(set(self.network_allowlist) & set(other.network_allowlist)) - if self.network_allowlist or other.network_allowlist - else () - ), - fs_read_roots=tuple( - sorted(set(self.fs_read_roots) & set(other.fs_read_roots)) - ) - if self.fs_read_roots and other.fs_read_roots - else (self.fs_read_roots or other.fs_read_roots), - fs_write_roots=tuple( - sorted(set(self.fs_write_roots) & set(other.fs_write_roots)) - ) - if self.fs_write_roots and other.fs_write_roots - else (self.fs_write_roots or other.fs_write_roots), + network_allowlist=allow, + fs_read_roots=read_roots, + fs_write_roots=write_roots, max_processes=_min_or(self.max_processes, other.max_processes), env=other.env if other.env is not None else self.env, ) -@dataclass(frozen=True, slots=True) -class SandboxResult: +class SandboxResult(TensorClass["nocast"]): """Outcome of a single :meth:`Sandbox.run` invocation. Attributes: @@ -125,10 +141,12 @@ class SandboxResult: wall_seconds: Observed wall-clock duration. timed_out: ``True`` if the subprocess hit :attr:`ResourceLimits.wall_seconds` before exiting. - truncated: ``True`` if stdout/stderr were truncated by an output cap. + truncated: ``True`` if stdout/stderr were truncated by an + output cap. artifacts: File contents emitted under - :attr:`ResourceLimits.fs_write_roots`, keyed by relative path. - Populated lazily by backends that support it; default empty. + :attr:`ResourceLimits.fs_write_roots`, keyed by relative + path. Populated lazily by backends that support it; + default empty. """ stdout: str @@ -137,36 +155,41 @@ class SandboxResult: wall_seconds: float timed_out: bool = False truncated: bool = False - artifacts: Mapping[str, bytes] = field(default_factory=dict) + artifacts: Mapping[str, bytes] | None = None @runtime_checkable class Sandbox(Protocol): - """An async context manager owning an isolated execution environment. + """Async context manager owning an isolated execution environment. - Lifecycle: ``open()`` is idempotent and required before ``run()``; - ``close()`` releases all OS resources. Use as ``async with sandbox:`` to - bracket lifecycle automatically. + Lifecycle: ``open()`` is idempotent and required before + ``run()``; ``close()`` releases all OS resources. Use as + ``async with sandbox:`` to bracket lifecycle automatically. :meth:`run` does *not* raise on tool exit codes. 
It raises - :class:`SandboxError` only on infrastructure failures (sandbox launch, - host kernel error). Per-call ``limits`` may only narrow construction - ``limits``; widening attempts are silently clamped. + :class:`SandboxError` only on infrastructure failures (sandbox + launch, host kernel error). Per-call ``limits`` may only narrow + construction ``limits``; widening attempts are silently clamped. - All paths in :meth:`write_file` / :meth:`read_file` are sandbox-virtual; - the backend is responsible for translating to host paths. + All paths in :meth:`write_file` / :meth:`read_file` are + sandbox-virtual; the backend is responsible for translating to + host paths. """ name: ClassVar[str] limits: ResourceLimits - async def open(self) -> None: ... + async def open(self) -> None: + ... - async def close(self) -> None: ... + async def close(self) -> None: + ... - async def __aenter__(self) -> Sandbox: ... + async def __aenter__(self) -> Sandbox: + ... - async def __aexit__(self, exc_type, exc, tb) -> None: ... + async def __aexit__(self, exc_type, exc, tb) -> None: + ... async def run( self, @@ -175,13 +198,14 @@ async def run( stdin: bytes | None = None, cwd: str | None = None, limits: ResourceLimits | None = None, - ) -> SandboxResult: ... + ) -> SandboxResult: + ... - async def write_file(self, path: str, data: bytes) -> None: ... + async def write_file(self, path: str, data: bytes) -> None: + ... - async def read_file( - self, path: str, max_bytes: int | None = None - ) -> bytes: ... + async def read_file(self, path: str, max_bytes: int | None = None) -> bytes: + ... __all__ = [ diff --git a/torchrl/envs/llm/agentic/sandbox/docker.py b/torchrl/envs/llm/agentic/sandbox/docker.py index c7b025ed318..40c64bd6439 100644 --- a/torchrl/envs/llm/agentic/sandbox/docker.py +++ b/torchrl/envs/llm/agentic/sandbox/docker.py @@ -14,7 +14,7 @@ from collections.abc import Sequence from typing import ClassVar -from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult +from .base import ResourceLimits, SandboxResult class DockerSandbox: @@ -28,7 +28,7 @@ def __init__( *, image: str = "python:3.11-slim", ) -> None: - self.limits = limits or ResourceLimits() + self.limits = limits if limits is not None else ResourceLimits() self.image = image async def open(self) -> None: diff --git a/torchrl/envs/llm/agentic/sandbox/e2b.py b/torchrl/envs/llm/agentic/sandbox/e2b.py index 7251d72c523..3780b7b2b49 100644 --- a/torchrl/envs/llm/agentic/sandbox/e2b.py +++ b/torchrl/envs/llm/agentic/sandbox/e2b.py @@ -13,7 +13,7 @@ from collections.abc import Sequence from typing import ClassVar -from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult +from .base import ResourceLimits, SandboxResult _has_e2b = importlib.util.find_spec("e2b") is not None @@ -24,7 +24,7 @@ class E2BSandbox: name: ClassVar[str] = "e2b" def __init__(self, limits: ResourceLimits | None = None) -> None: - self.limits = limits or ResourceLimits() + self.limits = limits if limits is not None else ResourceLimits() async def open(self) -> None: raise NotImplementedError( diff --git a/torchrl/envs/llm/agentic/sandbox/modal.py b/torchrl/envs/llm/agentic/sandbox/modal.py index 8dc9fc7435e..24d4847b45f 100644 --- a/torchrl/envs/llm/agentic/sandbox/modal.py +++ b/torchrl/envs/llm/agentic/sandbox/modal.py @@ -13,7 +13,7 @@ from collections.abc import Sequence from typing import ClassVar -from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult +from .base import ResourceLimits, SandboxResult _has_modal = 
importlib.util.find_spec("modal") is not None @@ -24,7 +24,7 @@ class ModalSandbox: name: ClassVar[str] = "modal" def __init__(self, limits: ResourceLimits | None = None) -> None: - self.limits = limits or ResourceLimits() + self.limits = limits if limits is not None else ResourceLimits() async def open(self) -> None: raise NotImplementedError( diff --git a/torchrl/envs/llm/agentic/sandbox/subprocess_bwrap.py b/torchrl/envs/llm/agentic/sandbox/subprocess_bwrap.py index 2b3014abbc8..78447c52ffc 100644 --- a/torchrl/envs/llm/agentic/sandbox/subprocess_bwrap.py +++ b/torchrl/envs/llm/agentic/sandbox/subprocess_bwrap.py @@ -22,7 +22,6 @@ from __future__ import annotations import asyncio -import importlib.util import os import shutil import time @@ -30,7 +29,7 @@ from pathlib import Path from typing import ClassVar -from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult +from .base import ResourceLimits, SandboxError, SandboxResult _OUTPUT_CAP = 1 << 20 @@ -71,7 +70,7 @@ def __init__( *, bwrap_path: str | None = None, ) -> None: - self.limits = limits or ResourceLimits() + self.limits = limits if limits is not None else ResourceLimits() self._bwrap = bwrap_path or shutil.which("bwrap") self._opened = False @@ -167,7 +166,7 @@ async def run( proc.communicate(stdin), timeout=eff.wall_seconds ) timed_out = False - except asyncio.TimeoutError: + except TimeoutError: proc.kill() try: out_b, err_b = await proc.communicate() @@ -200,9 +199,7 @@ async def write_file(self, path: str, data: bytes) -> None: Path(path).parent.mkdir(parents=True, exist_ok=True) Path(path).write_bytes(data) - async def read_file( - self, path: str, max_bytes: int | None = None - ) -> bytes: + async def read_file(self, path: str, max_bytes: int | None = None) -> bytes: if not self._opened: raise SandboxError("sandbox is not open; call open() first") b = Path(path).read_bytes() diff --git a/torchrl/envs/llm/agentic/sandbox/subprocess_seatbelt.py b/torchrl/envs/llm/agentic/sandbox/subprocess_seatbelt.py index 62a66c7e6d4..8cf233534b3 100644 --- a/torchrl/envs/llm/agentic/sandbox/subprocess_seatbelt.py +++ b/torchrl/envs/llm/agentic/sandbox/subprocess_seatbelt.py @@ -25,7 +25,7 @@ from pathlib import Path from typing import ClassVar -from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult +from .base import ResourceLimits, SandboxError, SandboxResult _OUTPUT_CAP = 1 << 20 @@ -71,7 +71,7 @@ class SeatbeltSandbox: name: ClassVar[str] = "seatbelt" def __init__(self, limits: ResourceLimits | None = None) -> None: - self.limits = limits or ResourceLimits() + self.limits = limits if limits is not None else ResourceLimits() self._exec = shutil.which("sandbox-exec") self._opened = False @@ -139,7 +139,7 @@ async def run( proc.communicate(stdin), timeout=eff.wall_seconds ) timed_out = False - except asyncio.TimeoutError: + except TimeoutError: proc.kill() try: out_b, err_b = await proc.communicate() @@ -170,9 +170,7 @@ async def write_file(self, path: str, data: bytes) -> None: Path(path).parent.mkdir(parents=True, exist_ok=True) Path(path).write_bytes(data) - async def read_file( - self, path: str, max_bytes: int | None = None - ) -> bytes: + async def read_file(self, path: str, max_bytes: int | None = None) -> bytes: if not self._opened: raise SandboxError("sandbox is not open; call open() first") b = Path(path).read_bytes() diff --git a/torchrl/envs/llm/agentic/sandbox/unsafe.py b/torchrl/envs/llm/agentic/sandbox/unsafe.py index 883111c9ba6..9af25dfbf04 100644 --- 
a/torchrl/envs/llm/agentic/sandbox/unsafe.py +++ b/torchrl/envs/llm/agentic/sandbox/unsafe.py @@ -21,7 +21,7 @@ from torchrl._utils import logger as torchrl_logger -from .base import ResourceLimits, Sandbox, SandboxError, SandboxResult +from .base import ResourceLimits, SandboxError, SandboxResult _OUTPUT_CAP = 1 << 20 # 1 MiB per stream @@ -51,7 +51,7 @@ class UnsafeSubprocessSandbox: name: ClassVar[str] = "unsafe-subprocess" def __init__(self, limits: ResourceLimits | None = None) -> None: - self.limits = limits or ResourceLimits() + self.limits = limits if limits is not None else ResourceLimits() self._opened = False async def open(self) -> None: @@ -116,7 +116,7 @@ async def run( timeout=eff.wall_seconds, ) timed_out = False - except asyncio.TimeoutError: + except TimeoutError: proc.kill() try: out_b, err_b = await proc.communicate() @@ -145,9 +145,7 @@ async def write_file(self, path: str, data: bytes) -> None: Path(path).parent.mkdir(parents=True, exist_ok=True) Path(path).write_bytes(data) - async def read_file( - self, path: str, max_bytes: int | None = None - ) -> bytes: + async def read_file(self, path: str, max_bytes: int | None = None) -> bytes: if not self._opened: raise SandboxError("sandbox is not open; call open() first") b = Path(path).read_bytes() diff --git a/torchrl/envs/llm/agentic/schema.py b/torchrl/envs/llm/agentic/schema.py index 1a92c3f71f8..dcd1fdf532a 100644 --- a/torchrl/envs/llm/agentic/schema.py +++ b/torchrl/envs/llm/agentic/schema.py @@ -34,9 +34,7 @@ class SchemaValidationError(ValueError): """Raised by :func:`validate_args` on a schema mismatch.""" -def validate_args( - args: Mapping[str, Any], schema: Mapping[str, Any] | None -) -> None: +def validate_args(args: Mapping[str, Any], schema: Mapping[str, Any] | None) -> None: """Validate ``args`` against a JSON Schema dict. Implements the subset that matters for tool-call dispatch: @@ -97,6 +95,4 @@ def json_schema_from_pydantic(model: Any) -> dict[str, Any]: ) if hasattr(model, "model_json_schema"): return model.model_json_schema() - raise TypeError( - f"{model!r} is not a pydantic v2 BaseModel subclass." - ) + raise TypeError(f"{model!r} is not a pydantic v2 BaseModel subclass.")
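
A minimal end-to-end sketch of how the pieces introduced in this patch compose by
hand until the ``ToolCompose`` orchestrator lands in the follow-up commit. It is
illustrative only: the ``"python"`` tool name and its input schema below are
hypothetical, and it assumes the modules re-export the names as declared above
(``validate_args`` and ``ToolResult`` from ``torchrl.envs.llm.agentic``, the parser,
sandbox, and REPL classes from their subpackages)::

    import asyncio

    from torchrl.envs.llm.agentic import ToolResult, validate_args
    from torchrl.envs.llm.agentic.parsers import JSONToolCallParser
    from torchrl.envs.llm.agentic.repl import SubprocessRepl
    from torchrl.envs.llm.agentic.sandbox import ResourceLimits, UnsafeSubprocessSandbox

    # Hypothetical input schema for a "python" tool; not part of this patch.
    CODE_TOOL_SCHEMA = {
        "type": "object",
        "properties": {"code": {"type": "string"}},
        "required": ["code"],
    }


    async def main() -> None:
        parser = JSONToolCallParser()
        # A model response in the JSON-block shape this patch documents.
        response = (
            '{"message": "Running it now.", '
            '"tools": [{"tool": "python", "args": {"code": "x = 40\\nprint(x + 2)"}}]}'
        )
        parsed = parser.parse(response)
        call = parsed.calls[0]
        validate_args(call.args, CODE_TOOL_SCHEMA)

        # UnsafeSubprocessSandbox gives no isolation and warns on open();
        # prefer default_sandbox() where bwrap or sandbox-exec is available.
        limits = ResourceLimits(wall_seconds=10, network="none")
        async with UnsafeSubprocessSandbox(limits) as sandbox:
            async with SubprocessRepl(sandbox) as repl:
                result = await repl.execute(call.args["code"], timeout=5)

        # Render the outcome back into the parser family's message shape,
        # correlated by the stable call_id assigned at parse time.
        tool_msg = parser.render_result(call.call_id, ToolResult.from_text(result.text))
        print(tool_msg)


    if __name__ == "__main__":
        asyncio.run(main())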