Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions benchmarks/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,99 @@ def setup():
warmup_rounds=3,
setup=setup,
)


# ----- Agentic ToolCompose dispatch benchmarks -----


class _SleepyTool:
"""Bench-only tool: simulates a network/I/O call via asyncio.sleep."""

description = "sleep"
input_schema = {
"type": "object",
"properties": {"ms": {"type": "integer"}},
}
output_schema = None
wants_state = False

def __init__(self, name: str) -> None:
self.name = name

async def setup(self) -> None:
pass

async def teardown(self) -> None:
pass

async def run(self, args, ctx):
import asyncio as _asyncio

from torchrl.envs.llm.agentic import ToolResult

await _asyncio.sleep(args.get("ms", 100) / 1000)
return ToolResult.from_text("ok")


@pytest.mark.benchmark(group="agentic-dispatch")
@pytest.mark.parametrize("n_tools", [3, 8])
def test_toolcompose_parallel_dispatch(benchmark, n_tools):
    """Bench parallel ToolCompose dispatch.

    With ``n_tools`` async tools each sleeping 50ms, parallel dispatch
    should bottom out near 50ms regardless of ``n_tools``; serial
    dispatch would scale linearly.
    """
    from torchrl.envs import TransformedEnv
    from torchrl.envs.llm import ChatEnv
    from torchrl.envs.llm.agentic import ToolCompose
    from torchrl.envs.llm.agentic.parsers import XMLToolCallParser

    set_list_to_stack(True).set()
    chat_env = ChatEnv(batch_size=(1,), input_mode="history")
    sleepy_tools = [_SleepyTool(f"t{idx}") for idx in range(n_tools)]
    env = TransformedEnv(
        chat_env, ToolCompose(tools=sleepy_tools, parser=XMLToolCallParser())
    )

    # One <tool> call per registered tool, each requesting a 50ms sleep.
    calls = [
        f'<tool name="t{idx}" tag="{idx}">{{"ms": 50}}</tool>'
        for idx in range(n_tools)
    ]
    fake = "".join(calls)

    def run_once():
        # Reset, append a fake assistant turn containing the tool calls,
        # then step so the transform dispatches all of them.
        td = env.reset(TensorDict({"query": "go"}, batch_size=(1,)))
        assistant_turn = History(role="assistant", content=fake).view(1, 1)
        td["history"].full = td["history"].prompt.extend(assistant_turn, dim=-1)
        env.step(td)

    benchmark(run_once)


@pytest.mark.benchmark(group="agentic-dispatch")
def test_toolcompose_single_call_baseline(benchmark):
    """One-call baseline so the n=3 / n=8 numbers are interpretable."""
    from torchrl.envs import TransformedEnv
    from torchrl.envs.llm import ChatEnv
    from torchrl.envs.llm.agentic import ToolCompose
    from torchrl.envs.llm.agentic.parsers import XMLToolCallParser

    set_list_to_stack(True).set()
    # Single-tool environment: one 50ms sleep per step.
    env = TransformedEnv(
        ChatEnv(batch_size=(1,), input_mode="history"),
        ToolCompose(
            tools=[_SleepyTool("t0")], parser=XMLToolCallParser()
        ),
    )
    fake = '<tool name="t0" tag="0">{"ms": 50}</tool>'

    def run_once():
        # Same reset/append/step cycle as the parallel benchmark, with a
        # single tool call in the assistant turn.
        td = env.reset(TensorDict({"query": "go"}, batch_size=(1,)))
        assistant_turn = History(role="assistant", content=fake).view(1, 1)
        td["history"].full = td["history"].prompt.extend(assistant_turn, dim=-1)
        env.step(td)

    benchmark(run_once)
Loading
Loading