diff --git a/experiments/run_miniwob.py b/experiments/run_miniwob.py
new file mode 100644
index 00000000..8af60cdb
--- /dev/null
+++ b/experiments/run_miniwob.py
@@ -0,0 +1,81 @@
+import argparse
+import logging
+import os
+
+from bgym import DEFAULT_BENCHMARKS
+from dotenv import load_dotenv
+
+from agentlab.agents.generic_agent.agent_configs import GPT5_MINI_FLAGS
+from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs
+from agentlab.agents.react_toolcall_agent import AgentConfig, LLMArgs, ReactToolCallAgentArgs
+from agentlab.backends.browser.mcp_playwright import MCPPlaywright
+from agentlab.backends.browser.playwright import SyncPlaywright
+from agentlab.benchmarks.miniwob import MiniWobBenchmark
+from agentlab.experiments.study import make_study
+from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
+
+fmt = "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(funcName)s() - %(message)s"
+logging.basicConfig(level=logging.INFO, force=True, format=fmt, handlers=[logging.StreamHandler()])
+logger = logging.getLogger(__name__)
+load_dotenv()
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Run MiniWob benchmark experiments")
+    parser.add_argument(
+        "--backend",
+        choices=["playwright", "mcp", "bgym"],
+        default="playwright",
+        help="Browser backend to use (default: playwright)",
+    )
+    parser.add_argument(
+        "--agent",
+        choices=["generic", "react"],
+        default="react",
+        help="Agent type to use (default: react)",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="miniwob",
+        help="Hydra config name to load (default: miniwob)",
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    if args.backend == "bgym":
+        benchmark = DEFAULT_BENCHMARKS["miniwob"](n_repeats=1)
+    elif args.backend == "playwright":
+        benchmark = MiniWobBenchmark(backend_cls=SyncPlaywright)
+    elif args.backend == "mcp":
+        benchmark = MiniWobBenchmark(backend_cls=MCPPlaywright)
+    else:
+        raise ValueError(f"Unknown backend: {args.backend}")
+
+    if args.agent == "generic":
+        agent_args = GenericAgentArgs(
+            chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-5-mini-2025-08-07"],
+            flags=GPT5_MINI_FLAGS,
+        )
+    else:  # react
+        agent_args = ReactToolCallAgentArgs(
+            llm_args=LLMArgs(
+                model_name="azure/gpt-5-mini", temperature=1.0, max_total_tokens=128000
+            ),
+            config=AgentConfig(),
+        )
+
+    study = make_study(
+        benchmark=benchmark,
+        agent_args=agent_args,
+        logging_level=logging.INFO,
+        logging_level_stdout=logging.INFO,
+    )
+    if os.environ.get("AGENTLAB_DEBUG"):
+        study.exp_args_list = study.exp_args_list[23:27]
+        study.run(n_jobs=1, n_relaunch=1, parallel_backend="sequential")
+    else:
+        study.run(n_jobs=8, n_relaunch=1, parallel_backend="ray")
diff --git a/src/agentlab/actions.py b/src/agentlab/actions.py
new file mode 100644
index 00000000..a0dd8d10
--- /dev/null
+++ b/src/agentlab/actions.py
@@ -0,0 +1,125 @@
+import json
+import logging
+from typing import Callable, Literal
+from uuid import uuid4
+
+from bgym import AbstractActionSet
+from langchain_core.utils.function_calling import convert_to_openai_tool
+from pydantic import BaseModel, Field
+
+from agentlab.llm.llm_utils import parse_html_tags_raise
+
+logger = logging.getLogger(__name__)
+
+
+class FunctionSpec(BaseModel):
+    """
+    A class representing the specification of a function.
+
+    Attributes:
+        name (str): The name of the function.
+        description (str): A brief description of the function.
+        parameters (dict): A dictionary containing the parameters of the function.
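+
+    Example (illustrative values, following the OpenAI JSON-schema convention
+    produced by convert_to_openai_tool):
+        FunctionSpec(
+            name="browser_click",
+            description="Click the element matching the selector",
+            parameters={"type": "object", "properties": {"selector": {"type": "string"}}},
+        )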
+    """
+
+    name: str
+    description: str
+    parameters: dict
+
+
+class ToolCall(BaseModel):
+    id: str = Field(default_factory=lambda: uuid4().hex)
+    name: str
+    arguments: dict = Field(default_factory=dict)
+
+    def llm_view(self, **kwargs) -> str:
+        return self.model_dump_json(indent=2)
+
+
+class ToolSpec(BaseModel):
+    """
+    ToolSpec is a model that represents a tool specification with a type and a function.
+
+    Attributes:
+        type (Literal["function"]): The type of the tool, which is always "function".
+        function (FunctionSpec): The specification of the function.
+    """
+
+    type: Literal["function"] = "function"
+    function: FunctionSpec
+
+    def description(self) -> str:
+        return f"{self.function.name} - {self.function.description}"
+
+    @classmethod
+    def from_function(cls, function: Callable):
+        """
+        Creates an instance of the class by validating the model from a given function.
+
+        Args:
+            function (Callable): The function to be converted and validated.
+
+        Returns:
+            (ToolSpec): An instance of the class with the validated model.
+        """
+        return cls.model_validate(convert_to_openai_tool(function))
+
+
+class ToolsActionSet(AbstractActionSet):
+    multiaction: bool = False
+    strict: bool = False
+
+    def __init__(self, actions: list[ToolSpec]):
+        self.actions = actions
+
+    def describe(self, with_long_description: bool = True, with_examples: bool = True) -> str:
+        descs = []
+        for action in self.actions:
+            desc = f"## {action.description()}.\n Schema: {action.model_dump_json(indent=2)}"
+            descs.append(desc)
+        tools_description = "\n".join(descs)
+        return tools_description
+
+    def example_action(self, abstract: bool) -> str:
+        if abstract:
+            return """{
+    "name": "<action_name>",
+    "arguments": {
+        "<argument_name_1>": "<argument_value_1>",
+        "<argument_name_2>": "<argument_value_2>",
+        ...
+    }
+}"""
+        else:
+            return """{
+    "name": "browser_click",
+    "arguments": {
+        "element": "button with year 2022",
+        "ref": "e26"
+    }
+}"""
+
+    @classmethod
+    def parse_action(cls, llm_output: str) -> ToolCall:
+        logger.info(f"Parsing action: {llm_output}")
+        if "<action>" in llm_output:
+            content_dict, valid, retry_message = parse_html_tags_raise(llm_output, keys=["action"])
+            if not valid or "action" not in content_dict:
+                raise ValueError(
+                    f"Invalid action: llm_output: {llm_output}, retry_message: {retry_message}"
+                )
+            action_str = content_dict["action"]
+        else:
+            action_str = llm_output
+        try:
+            action_dict = json.loads(action_str)
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to parse action: {action_str}")
+        return ToolCall(name=action_dict["name"], arguments=action_dict["arguments"])
+
+    def to_python_code(self, action) -> str:
+        return action
+
+    def tools(self) -> list[dict]:
+        """Returns the list of tool spec dicts for LLM consumption."""
+        return [tool.model_dump() for tool in self.actions]
diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py
index d1f48f76..f65b2132 100644
--- a/src/agentlab/agents/generic_agent/generic_agent.py
+++ b/src/agentlab/agents/generic_agent/generic_agent.py
@@ -10,13 +10,12 @@
 from copy import deepcopy
 from dataclasses import asdict, dataclass
-from functools import partial
 from warnings import warn
 
-import bgym
 from bgym import Benchmark
 from browsergym.experiments.agent import Agent, AgentInfo
 
+from agentlab.actions import ToolsActionSet
 from agentlab.agents import dynamic_prompting as dp
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.llm.chat_api import BaseModelArgs
@@ -65,9 +64,12 @@ def prepare(self):
     def close(self):
         return self.chat_model_args.close_server()
 
-    def make_agent(self):
+    def make_agent(self, actions: list | None = None):
         return GenericAgent(
-            chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry
+            chat_model_args=self.chat_model_args,
+            flags=self.flags,
+            max_retry=self.max_retry,
+            actions=actions,
         )
 
 
@@ -78,6 +80,7 @@ def __init__(
         chat_model_args: BaseModelArgs,
         flags: GenericPromptFlags,
         max_retry: int = 4,
+        actions: list | None = None,
     ):
         self.chat_llm = chat_model_args.make_model()
 
@@ -85,8 +88,13 @@
         self.max_retry = max_retry
 
         self.flags = flags
-        self.action_set = self.flags.action.action_set.make_action_set()
-        self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)
+        if actions is not None:
+            self.action_set = ToolsActionSet(actions=actions)
+            self.flags.action.action_set = self.action_set
+            self._obs_preprocessor = lambda obs: obs
+        else:
+            self.action_set = self.flags.action.action_set.make_action_set()
+            self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)
 
         self._check_flag_constancy()
         self.reset(seed=None)
@@ -157,7 +165,7 @@ def get_action(self, obs):
             stats=stats,
             extra_info={"chat_model_args": asdict(self.chat_model_args)},
         )
-        return ans_dict["action"], agent_info
+        return ans_dict["action"], asdict(agent_info)
 
     def reset(self, seed=None):
         self.seed = seed
diff --git a/src/agentlab/agents/react_toolcall_agent.py b/src/agentlab/agents/react_toolcall_agent.py
new file mode 100644
index 00000000..01df0836
--- /dev/null
+++ b/src/agentlab/agents/react_toolcall_agent.py
@@ -0,0 +1,247 @@
+import json
+import logging
+import pprint
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Literal
+
+import numpy as np
+from litellm import completion
+from litellm.types.utils import Message, ModelResponse
+from litellm.utils import token_counter
+from PIL import Image
+from termcolor import colored
+
+from agentlab.actions import ToolCall, ToolsActionSet, ToolSpec
+from agentlab.agents.agent_args import AgentArgs
+from agentlab.llm.chat_api import BaseModelArgs
+from agentlab.llm.llm_utils import image_to_png_base64_url
+
+logger = logging.getLogger(__name__)
+
+
+class LLMArgs(BaseModelArgs):
+    reasoning_effort: Literal["minimal", "low", "medium", "high"] = "low"
+    num_retries: int = 3
+
+    def make_model(self) -> Callable:
+        return partial(
+            completion,
+            model=self.model_name,
+            temperature=self.temperature,
+            max_tokens=self.max_total_tokens,
+            max_completion_tokens=self.max_new_tokens,
+            reasoning_effort=self.reasoning_effort,
+            num_retries=self.num_retries,
+            tool_choice="auto",
+            parallel_tool_calls=False,
+        )
+
+
+@dataclass
+class AgentConfig:
+    use_html: bool = True
+    use_axtree: bool = False
+    use_screenshot: bool = True
+    max_actions: int = 10
+    max_obs_chars: int = 100000  # truncate long observations to N chars
+    max_history_tokens: int = 120000
+    system_prompt: str = """
+You are an expert AI Agent trained to assist users with complex web tasks.
+Your role is to understand the goal, perform actions until the goal is accomplished and respond in a helpful and accurate manner.
+Keep your replies brief, concise, direct and on topic. Prioritize clarity and avoid over-elaboration.
+Do not express emotions or opinions."""
+    guidance: str = """
+Think along the following lines:
+1. Summarize the last observation and describe the visible changes in the state.
+2. Evaluate action success, explain impact on task and next steps.
+3. If you see any errors in the last observation, think about how to recover from them. If there is no error, just move on.
+4. List next steps to move towards the goal and propose next immediate action.
+Then produce the single function call that performs the proposed action. If the task is complete, call the final_step tool."""
+    summarize_system_prompt: str = """
+You are a helpful assistant that summarizes agent interaction history. The following messages are the history to summarize:"""
+    summarize_prompt: str = """
+Summarize the presented agent interaction history concisely.
+Focus on:
+- The original goal
+- Key actions taken and their outcomes
+- Important errors or obstacles encountered
+- Current progress toward the goal
+Provide a concise summary that preserves all information needed to continue the task."""
+
+
+def user_message(content: str | list[dict]) -> dict:
+    return {"role": "user", "content": content}
+
+
+class ReactToolCallAgent:
+    def __init__(
+        self,
+        action_set: ToolsActionSet,
+        llm: Callable[..., ModelResponse],
+        token_counter: Callable[..., int],
+        config: AgentConfig,
+    ):
+        self.action_set = action_set
+        self.tools = self.action_set.tools()
+        self.history: list[dict | Message] = [{"role": "system", "content": config.system_prompt}]
+        self.llm = llm
+        self.token_counter = token_counter
+        self.config = config
+        self.last_tool_call_id: str = ""
+
+    def obs_preprocessor(self, obs: dict) -> dict:
+        return obs
+
+    def obs_to_messages(self, obs: dict) -> list[dict]:
+        """
+        Convert the observation dictionary into a list of chat messages for LiteLLM.
+        """
+        goal_obj = obs.pop("goal_object", None)
+        if not self.config.use_html:
+            obs.pop("pruned_html", None)
+            obs.pop("html", None)
+        if not self.config.use_axtree:
+            obs.pop("axtree_txt", None)
+        if not self.config.use_screenshot:
+            obs.pop("screenshot", None)
+        images = {k: v for k, v in obs.items() if isinstance(v, (Image.Image, np.ndarray))}
+        texts = {k: v for k, v in obs.items() if v is not None and isinstance(v, str) and v != ""}
+        messages = []
+
+        if not self.last_tool_call_id and goal_obj is not None and len(goal_obj) > 0 and "text" in goal_obj[0]:
+            # it's the first observation (there is no tool_call_id yet), so include the goal
+            goal = goal_obj[0]["text"]
+            messages.append(user_message(f"Goal: {goal}"))
+
+        text = "\n\n".join([f"## {k}\n{v}" for k, v in texts.items()])[:self.config.max_obs_chars]
+        if self.last_tool_call_id:
+            message = {
+                "role": "tool",
+                "tool_call_id": self.last_tool_call_id,
+                "content": text,
+            }
+        else:
+            message = user_message(text)
+        messages.append(message)
+
+        if self.config.use_screenshot:
+            for caption, image in images.items():
+                image_content = [
+                    {"type": "text", "text": caption},
+                    {"type": "image_url", "image_url": {"url": image_to_png_base64_url(image)}},
+                ]
+                messages.append(user_message(image_content))
+
+        return messages
+
+    def get_action(self, obs: dict) -> tuple[ToolCall, dict]:
+        if self.max_actions_reached():
+            logger.warning("Max actions reached, stopping agent.")
+            return ToolCall(name="final_step"), {}
+
+        self.history += self.obs_to_messages(obs)
+        self.maybe_compact_history()
+        messages = self.history + [{"role": "user", "content": self.config.guidance}]
+
+        try:
+            logger.info(colored(f"Prompt:\n{pprint.pformat([str(m)[:500] for m in messages], width=120)}", "blue"))
+            response = self.llm(tools=self.tools, messages=messages)
+            message = response.choices[0].message  # type: ignore
+        except Exception as e:
+            logger.exception(f"Error getting LLM response: {e}.
Prompt: {messages}") + raise e + logger.info(colored(f"LLM response:\n{pprint.pformat(message, width=120)}", "green")) + + self.history.append(message) + thoughts = self.thoughts_from_message(message) + action = self.action_from_message(message) + return action, {"think": thoughts, "chat_messages": self.history} + + def max_actions_reached(self) -> bool: + prev_actions = [msg for msg in self.history if isinstance(msg, Message) and msg.tool_calls] + return len(prev_actions) >= self.config.max_actions + + def thoughts_from_message(self, message: Message) -> str: + """Extract the agent's thoughts from the LLM message.""" + thoughts = [] + if reasoning := message.get("reasoning_content"): + thoughts.append(reasoning) + if blocks := message.get("thinking_blocks"): + for block in blocks: + if thinking := getattr(block, "content", None) or getattr(block, "thinking", None): + thoughts.append(thinking) + if message.content: + thoughts.append(message.content) + logger.info(colored(f"LLM thoughts: {thoughts}", "cyan")) + return "\n\n".join(thoughts) + + def action_from_message(self, message: Message) -> ToolCall: + """Parse the ToolCall from the LLM message.""" + if message.tool_calls: + if len(message.tool_calls) > 1: + logger.warning("Multiple tool calls found in LLM response, using the first one.") + tool_call = message.tool_calls[0] + name = tool_call.function.name + assert name, "Tool call must have a name." + args = json.loads(tool_call.function.arguments) + action = ToolCall(id=tool_call.id, name=name, arguments=args) + self.last_tool_call_id = action.id + logger.info(colored(f"Parsed tool call: {action}", "magenta")) + else: + raise ValueError(f"No tool call found in LLM response: {message}") + return action + + def maybe_compact_history(self): + tokens = self.token_counter(messages=self.history) + if tokens > self.config.max_history_tokens: + logger.info("Compacting history due to length.") + self.compact_history() + short_tokens = self.token_counter(messages=self.history) + logger.info(f"Compacted history from {tokens} to {short_tokens} tokens.") + + def compact_history(self): + """ + Compact the history by summarizing the first half of messages with the LLM. + Updates self.history in place by replacing the first half with the summary message. 
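+
+        Illustrative example (hypothetical message counts): a history of
+        [system, m1, ..., m8] becomes [system, summary(m1..m4), m5, ..., m8],
+        where the summary is a single user message produced by the LLM.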
+        """
+        system_msg = self.history[0]
+        rest = self.history[1:]
+        midpoint = len(rest) // 2
+        messages = [
+            {"role": "system", "content": self.config.summarize_system_prompt},
+            *rest[:midpoint],
+            {"role": "user", "content": self.config.summarize_prompt},
+        ]
+
+        try:
+            response = self.llm(messages=messages)
+            summary = response.choices[0].message.content  # type: ignore
+        except Exception as e:
+            logger.exception(f"Error compacting history: {e}")
+            raise
+
+        logger.info(colored(f"Compacted {midpoint} messages into summary:\n{summary}", "cyan"))
+        # Rebuild history: system + summary + remaining messages
+        summary_message = {"role": "user", "content": f"## Previous Interaction:\n{summary}"}
+        self.history = [system_msg, summary_message, *rest[midpoint:]]
+
+    def get_training_pairs(self) -> list[tuple[list[dict | Message], Message]]:
+        input_output_pairs = []
+        prev_history = []
+        for msg in self.history:
+            if isinstance(msg, Message):
+                input_output_pairs.append((prev_history.copy(), msg))
+            prev_history.append(msg)
+        return input_output_pairs
+
+
+@dataclass
+class ReactToolCallAgentArgs(AgentArgs):
+    llm_args: LLMArgs = None  # type: ignore
+    config: AgentConfig = None  # type: ignore
+
+    def make_agent(self, actions: list[ToolSpec]) -> ReactToolCallAgent:
+        llm = self.llm_args.make_model()
+        counter = partial(token_counter, model=self.llm_args.model_name)
+        action_set = ToolsActionSet(actions=actions)
+        return ReactToolCallAgent(action_set, llm, counter, self.config)
diff --git a/src/agentlab/backends/__init__.py b/src/agentlab/backends/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/agentlab/backends/browser/__init__.py b/src/agentlab/backends/browser/__init__.py
new file mode 100644
index 00000000..9fc3f071
--- /dev/null
+++ b/src/agentlab/backends/browser/__init__.py
@@ -0,0 +1,17 @@
+from agentlab.backends.browser.base import AsyncBrowserBackend, BrowserBackend
+from agentlab.backends.browser.env import BrowserEnv, BrowserEnvArgs
+from agentlab.backends.browser.mcp import MCPBrowserBackend, MCPClient
+from agentlab.backends.browser.mcp_playwright import MCPPlaywright
+from agentlab.backends.browser.playwright import AsyncPlaywright, SyncPlaywright
+
+__all__ = [
+    "BrowserBackend",
+    "AsyncBrowserBackend",
+    "BrowserEnv",
+    "BrowserEnvArgs",
+    "MCPBrowserBackend",
+    "MCPClient",
+    "MCPPlaywright",
+    "AsyncPlaywright",
+    "SyncPlaywright",
+]
diff --git a/src/agentlab/backends/browser/base.py b/src/agentlab/backends/browser/base.py
new file mode 100644
index 00000000..d11d0d5f
--- /dev/null
+++ b/src/agentlab/backends/browser/base.py
@@ -0,0 +1,102 @@
+import logging
+from abc import ABC, abstractmethod
+
+from PIL import Image
+from pydantic import BaseModel
+
+from agentlab.actions import ToolCall, ToolSpec
+
+logger = logging.getLogger(__name__)
+
+
+class BrowserBackend(BaseModel, ABC):
+    has_pw_page: bool = False
+
+    @abstractmethod
+    def initialize(self) -> None:
+        pass
+
+    @abstractmethod
+    def evaluate_js(self, js: str) -> str | dict | list:
+        return ""
+
+    @abstractmethod
+    def goto(self, url: str) -> str:
+        pass
+
+    @abstractmethod
+    def page_html(self) -> str:
+        pass
+
+    @abstractmethod
+    def page_screenshot(self) -> Image.Image:
+        pass
+
+    @abstractmethod
+    def page_axtree(self) -> str:
+        pass
+
+    @abstractmethod
+    def step(self, action: ToolCall) -> dict:
+        pass
+
+    @abstractmethod
+    def actions(self) -> list[ToolSpec]:
+        pass
+
+    @abstractmethod
+    def close(self) -> None:
+        pass
+
+    @property
+    def page(self):
+        raise NotImplementedError("Direct access to the playwright page is not supported.")
+
+
+class AsyncBrowserBackend(BaseModel, ABC):
+    """Abstract base class for async browser backends."""
+
+    has_pw_page: bool = False
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    @abstractmethod
+    async def initialize(self) -> None:
+        pass
+
+    @abstractmethod
+    async def evaluate_js(self, js: str) -> str | dict | list:
+        pass
+
+    @abstractmethod
+    async def goto(self, url: str) -> None:
+        pass
+
+    @abstractmethod
+    async def page_html(self) -> str:
+        pass
+
+    @abstractmethod
+    async def page_screenshot(self) -> Image.Image:
+        pass
+
+    @abstractmethod
+    async def page_axtree(self) -> str:
+        pass
+
+    @abstractmethod
+    async def step(self, action: ToolCall) -> dict:
+        pass
+
+    @abstractmethod
+    def actions(self) -> list[ToolSpec]:
+        pass
+
+    @abstractmethod
+    async def close(self) -> None:
+        pass
+
+    @property
+    def page(self):
+        raise NotImplementedError("Direct access to the playwright page is not supported.")
diff --git a/src/agentlab/backends/browser/env.py b/src/agentlab/backends/browser/env.py
new file mode 100644
index 00000000..bc9410f2
--- /dev/null
+++ b/src/agentlab/backends/browser/env.py
@@ -0,0 +1,142 @@
+import logging
+import time
+from dataclasses import dataclass
+from pathlib import Path
+
+from agentlab.actions import ToolCall, ToolsActionSet, ToolSpec
+from agentlab.backends.browser.base import BrowserBackend
+from agentlab.benchmarks.abstract_env import AbstractEnv, AbstractEnvArgs
+from agentlab.benchmarks.web_task import AbstractWebTask
+
+logger = logging.getLogger(__name__)
+
+
+def final_step():
+    """
+    Finish the task execution.
+    """
+    return {
+        "pruned_html": "Task finished",
+        "axtree_txt": "",
+        "last_action_error": "",
+        "focused_element_bid": "",
+    }
+
+
+class BrowserEnv(AbstractEnv):
+    def __init__(
+        self, task_name: str, task: AbstractWebTask, backend: BrowserBackend, seed: int = 0
+    ):
+        self.task_name = task_name
+        self.task = task
+        self.seed = seed
+        self._turns = 0
+        self.backend = backend
+        self.backend.initialize()
+        self.goal = ""
+
+    def reset(self, seed: int):
+        self.seed = seed
+        self.goal, task_info = self.task.setup(backend=self.backend)
+        obs = self._get_obs()
+        obs = self.task.obs_postprocess(obs)
+        return obs, task_info
+
+    def _get_obs(self) -> dict:
+        html = self.backend.page_html()
+        screenshot = self.backend.page_screenshot()
+        axtree = self.backend.page_axtree()
+        obs = {
+            "goal_object": [{"type": "text", "text": self.goal}],
+            "html": html,
+            "axtree_txt": axtree,
+            "screenshot": screenshot,
+            "last_action_error": "",
+            "focused_element_bid": "",
+        }
+        return obs
+
+    def step(self, action: ToolCall | str) -> tuple[dict, float, bool, bool, dict]:
+        if isinstance(action, str):
+            action = ToolsActionSet.parse_action(action)
+        logger.info(f"BrowserEnv.step() called with action {action}")
+
+        action_exec_start = time.time()
+        done = action.name == "final_step"
+        if done:
+            observation = final_step()
+        else:
+            observation = self.backend.step(action)
+        action_exec_stop = time.time()
+        self._turns += 1
+        if isinstance(self.task, AbstractWebTask):
+            truncated = self._turns >= self.task.max_turns
+        else:
+            truncated = False
+
+        observation = self.obs_postprocess(observation)
+
+        reward, info = self.task.validate()
+        if info.get("done", False):
+            done = True
+
+        env_info = {
+            **info,
+            "action_exec_start": action_exec_start,
+            "action_exec_stop": action_exec_stop,
+            "action_exec_timeout": 0.0,
+        }
+        logger.info(f"Action result in observation: {observation}")
+        return observation, reward, done, truncated, env_info
+
+    def obs_postprocess(self, obs: dict) -> dict:
+        if "goal_object" not in obs:
+            obs["goal_object"] = [{"type": "text", "text": self.goal}]
+        if "last_action_error" not in obs:
+            obs["last_action_error"] = ""
+        if "focused_element_bid" not in obs:
+            obs["focused_element_bid"] = ""
+        if isinstance(self.task, AbstractWebTask):
+            obs = self.task.obs_postprocess(obs)
+        return obs
+
+    def close(self):
+        self.task.teardown()
+        self.backend.close()
+
+    def actions(self) -> list[ToolSpec]:
+        all_actions = self.backend.actions()
+        if isinstance(self.task, AbstractWebTask):
+            filtered_actions = self.task.filter_actions(all_actions)
+            logger.info(
+                f"Filtered {len(filtered_actions)} actions out of {len(all_actions)} for dataset {self.task.dataset}"
+            )
+        else:
+            filtered_actions = all_actions
+        final_step_action = ToolSpec.from_function(final_step)
+        return filtered_actions + [final_step_action]
+
+
+@dataclass
+class BrowserEnvArgs(AbstractEnvArgs):
+    task: AbstractWebTask
+    task_seed: int
+    task_name: str
+    backend_cls: type[BrowserBackend]
+
+    def __init__(
+        self,
+        task: AbstractWebTask,
+        backend_cls: type[BrowserBackend],
+        task_seed: int = 0,
+    ):
+        self.task_name = f"{task.dataset}.{task.task_id}"
+        self.task = task
+        self.task_seed = task_seed
+        self.backend_cls = backend_cls
+
+    def make_env(self, exp_dir: Path) -> BrowserEnv:
+        backend = self.backend_cls()
+        env = BrowserEnv(
+            task_name=self.task_name, task=self.task, backend=backend, seed=self.task_seed
+        )
+        return env
diff --git a/src/agentlab/backends/browser/mcp.py b/src/agentlab/backends/browser/mcp.py
new file mode 100644
index 00000000..6040d532
--- /dev/null
+++ b/src/agentlab/backends/browser/mcp.py
@@ -0,0 +1,175 @@
+import asyncio
+import json
+import logging
+import os
+from contextlib import AsyncExitStack
+from datetime import timedelta
+from typing import Any
+
+from mcp import ClientSession, StdioServerParameters, stdio_client
+from mcp import Tool as MCPTool
+from mcp.types import CallToolResult, ContentBlock, TextContent
+
+from agentlab.actions import FunctionSpec, ToolCall, ToolSpec
+from agentlab.backends.browser.base import BrowserBackend
+
+logger = logging.getLogger(__name__)
+
+
+class MCPClient:
+    def __init__(self, config_path: str, read_timeout_seconds: int = 10) -> None:
+        self.servers = self.load_config(config_path)
+        self.sessions: dict[str, ClientSession] = {}
+        self.tools: dict[str, MCPTool] = {}
+        self.tool_to_server: dict[str, str] = {}
+        self.read_timeout_seconds = read_timeout_seconds
+        self.exit_stack = AsyncExitStack()
+        self.loop: asyncio.AbstractEventLoop
+
+    def initialize(self):
+        try:
+            self.loop = asyncio.get_event_loop()
+        except RuntimeError:
+            self.loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(self.loop)
+        self.loop.run_until_complete(self.start_servers())
+
+    async def ainitialize(self) -> None:
+        await self.start_servers()
+
+    async def start_servers(self):
+        for server_name, server_params in self.servers.items():
+            stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
+            stdio, write = stdio_transport
+            session = await self.exit_stack.enter_async_context(
+                ClientSession(
+                    stdio, write, read_timeout_seconds=timedelta(seconds=self.read_timeout_seconds)
+                )
+            )
+            await session.initialize()
+            self.sessions[server_name] = session
+            response = await session.list_tools()
+            for tool in response.tools:
+                if tool.name in self.tools:
+                    raise Exception(
+                        f"Tools conflict!
Tool {tool.name} already provided by server '{self.tool_to_server[tool.name]}'" + ) + self.tools[tool.name] = tool + self.tool_to_server[tool.name] = server_name + logger.info( + f"Connected to MCP server '{server_name}' with tools: {[tool.name for tool in response.tools]}" + ) + logger.info(f"Started {len(self.servers)} MCP servers") + + def load_config(self, config_path) -> dict[str, StdioServerParameters]: + assert os.path.exists(config_path), f"Config path {config_path} does not exist" + self.config_path = config_path + + try: + with open(config_path, "r") as f: + self.config = json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse {config_path}, invalid json: {e}") + try: + server_configs: dict[str, dict] = self.config["mcpServers"] + assert isinstance(server_configs, dict), "mcpServers must be a dict" + assert len(server_configs) > 0, "mcpServers dict is empty" + except Exception as e: + raise ValueError(f"Failed to get MCP server configs from {config_path}: {e}") + + servers: dict[str, StdioServerParameters] = {} + for server_name, server_config_dict in server_configs.items(): + try: + server_config_dict = self.prepare_env_vars(server_config_dict) + server_params = StdioServerParameters.model_validate(server_config_dict) + except Exception as e: + raise ValueError(f"Failed to parse server config {server_config_dict}: {e}") + servers[server_name] = server_params + logger.info(f"Loaded {len(servers)} MCP server configs from {config_path}") + return servers + + def prepare_env_vars(self, server_config_dict: dict) -> dict: + if server_env := server_config_dict.get("env"): + for env_var, env_value in server_env.items(): + if ( + env_var in os.environ and not env_value + ): # reuse existing env var value if not set in config + logger.info(f"Set mcp server env var {env_var} from current environment") + server_config_dict["env"][env_var] = os.environ[env_var] + return server_config_dict + + def call_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: + result = self.loop.run_until_complete(self.acall_tool(tool_name, tool_args)) + return result + + async def acall_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: + server_name = self.check_tool_exists(tool_name) + result = await self._call_tool(server_name, tool_name, tool_args) + return result + + async def _call_tool( + self, server_name: str, tool_name: str, tool_args: dict[str, Any] + ) -> CallToolResult: + try: + session = self.sessions[server_name] + result = await session.call_tool(tool_name, tool_args) + except Exception as e: + logger.exception(f"Error calling tool {tool_name}: {e}") + raise e + return result + + def check_tool_exists(self, tool_name): + try: + server_name = self.tool_to_server[tool_name] + except KeyError: + raise Exception(f"Tool {tool_name} not found in any of the MCP servers") + return server_name + + def actions(self) -> list[ToolSpec]: + return [ + ToolSpec( + function=FunctionSpec( + name=tool.name, description=tool.description or "", parameters=tool.inputSchema + ) + ) + for tool in self.tools.values() + ] + + async def aclose(self) -> None: + await self.exit_stack.aclose() + + def close(self) -> None: + self.loop.run_until_complete(self.aclose()) + + +class MCPBrowserBackend(BrowserBackend): + config_path: str + _mcp: MCPClient + + def initialize(self) -> None: + self._mcp = MCPClient(config_path=self.config_path) + self._mcp.initialize() + + def step(self, action: ToolCall) -> dict: + contents = self.call_tool(action.name, 
action.arguments)
+        action_result = "\n\n".join([c.text for c in contents if c.type == "text"])
+        images = {f"image_{i}": c for i, c in enumerate(contents) if c.type == "image"}
+        return {
+            "action_result": action_result,
+            **images,
+        }
+
+    def call_tool(self, tool_name: str, arguments: dict) -> list[ContentBlock]:
+        tool_result = self._mcp.call_tool(tool_name, arguments)
+        if tool_result.isError:
+            return [
+                TextContent(type="text", text=f"Error calling tool {tool_name}")
+            ] + tool_result.content
+        return tool_result.content
+
+    def actions(self) -> list[ToolSpec]:
+        return list(self._mcp.actions())
+
+    def close(self) -> None:
+        try:
+            self._mcp.close()
+        except Exception:
+            pass
diff --git a/src/agentlab/backends/browser/mcp_playwright.json b/src/agentlab/backends/browser/mcp_playwright.json
new file mode 100644
index 00000000..b79e4f77
--- /dev/null
+++ b/src/agentlab/backends/browser/mcp_playwright.json
@@ -0,0 +1,19 @@
+{
+    "mcpServers": {
+        "playwright": {
+            "command": "npx",
+            "args": [
+                "@playwright/mcp@latest",
+                "--browser",
+                "firefox",
+                "--headless",
+                "--isolated",
+                "--caps",
+                "vision"
+            ],
+            "env": {
+                "PLAYWRIGHT_BROWSERS_PATH": ""
+            }
+        }
+    }
+}
diff --git a/src/agentlab/backends/browser/mcp_playwright.py b/src/agentlab/backends/browser/mcp_playwright.py
new file mode 100644
index 00000000..d2edec85
--- /dev/null
+++ b/src/agentlab/backends/browser/mcp_playwright.py
@@ -0,0 +1,72 @@
+import base64
+import logging
+from io import BytesIO
+
+from PIL import Image
+
+from agentlab.actions import ToolCall
+from agentlab.backends.browser.mcp import MCPBrowserBackend
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_CONFIG_PATH = "src/agentlab/backends/browser/mcp_playwright.json"
+
+
+class MCPPlaywright(MCPBrowserBackend):
+    config_path: str = DEFAULT_CONFIG_PATH
+
+    def evaluate_js(self, js: str):
+        contents = self.call_tool("browser_evaluate", {"function": js})
+        raw_response = "\n".join([c.text for c in contents if c.type == "text"])
+        try:
+            _, half_response = raw_response.split("### Result", maxsplit=1)
+            result_str, _ = half_response.split("\n### Ran", maxsplit=1)
+            result_str = result_str.strip()
+        except Exception as e:
+            logger.error(f"Error parsing JS result: {e}. Raw result: {raw_response}")
+            raise e
+        return result_str
+
+    def step(self, action: ToolCall) -> dict:
+        contents = self.call_tool(action.name, action.arguments)
+        logger.info(f"Step result has {len(contents)} contents")
+        action_result = "\n".join(
+            [c.text for c in contents if c.type == "text" and "# Ran Playwright code" not in c.text]
+        )
+        html = self.page_html()
+        screenshot = self.page_screenshot()
+        axtree = self.page_axtree()
+        return {
+            "action_result": action_result,
+            "html": html,
+            "axtree_txt": axtree,
+            "screenshot": screenshot,
+        }
+
+    def page_html(self) -> str:
+        contents = self.call_tool(
+            "browser_evaluate", {"function": "document.documentElement.outerHTML"}
+        )
+        raw_response = "\n".join([c.text for c in contents if c.type == "text"])
+        try:
+            _, half_response = raw_response.split("### Result", maxsplit=1)
+            result_str, _ = half_response.split("\n### Ran", maxsplit=1)
+            return result_str.strip()
+        except Exception as e:
+            logger.error(f"Error parsing page_html result: {e}. Raw result: {raw_response}")
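+            # fall back to an empty page rather than failing the whole step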
+            return ""
+
+    def page_axtree(self) -> str:
+        contents = self.call_tool("browser_snapshot", {})
+        return "\n".join([c.text for c in contents if c.type == "text"])
+
+    def page_screenshot(self) -> Image.Image:
+        contents = self.call_tool("browser_take_screenshot", {})
+        content = [c for c in contents if c.type == "image"][0]
+        image_base64 = content.data
+        image = Image.open(BytesIO(base64.b64decode(image_base64)))
+        return image
+
+    def goto(self, url: str) -> str:
+        contents = self.call_tool("browser_navigate", {"url": url})
+        return "\n".join([c.text for c in contents if c.type == "text"])
diff --git a/src/agentlab/backends/browser/playwright.py b/src/agentlab/backends/browser/playwright.py
new file mode 100644
index 00000000..01a306bf
--- /dev/null
+++ b/src/agentlab/backends/browser/playwright.py
@@ -0,0 +1,298 @@
+import logging
+import time
+from io import BytesIO
+from typing import Any, Callable
+
+from PIL import Image
+from playwright.async_api import Page as AsyncPage
+from playwright.async_api import async_playwright
+from playwright.sync_api import Page as SyncPage
+from playwright.sync_api import sync_playwright
+
+from agentlab.actions import ToolCall, ToolSpec
+from agentlab.backends.browser.base import AsyncBrowserBackend, BrowserBackend
+
+logger = logging.getLogger(__name__)
+
+
+_pw = None  # Global Playwright instance for SyncPlaywright
+_browser = None  # Global Browser instance for SyncPlaywright
+
+
+class SyncPlaywright(BrowserBackend):
+    """Fully synchronous Playwright backend using playwright.sync_api."""
+
+    has_pw_page: bool = True
+    _actions: dict[str, Callable]
+    _page: SyncPage
+
+    def model_post_init(self, __context: Any):
+        self._actions = {
+            "browser_press_key": self.browser_press_key,
+            "browser_type": self.browser_type,
+            "browser_click": self.browser_click,
+            "browser_drag": self.browser_drag,
+            "browser_hover": self.browser_hover,
+            "browser_select_option": self.browser_select_option,
+            "browser_mouse_click_xy": self.browser_mouse_click_xy,
+            "browser_wait": self.browser_wait,
+            "browser_back": self.browser_back,
+            "browser_forward": self.browser_forward,
+        }
+
+    def initialize(self):
+        global _pw, _browser
+        if _pw is None:
+            _pw = sync_playwright().start()
+        if _browser is None:
+            _browser = _pw.chromium.launch(headless=True, chromium_sandbox=True)
+
+        self._page = _browser.new_page()
+
+    @property
+    def page(self) -> SyncPage:
+        return self._page
+
+    def browser_press_key(self, key: str):
+        """Press a key on the keyboard."""
+        self._page.keyboard.press(key)
+
+    def browser_type(self, selector: str, text: str):
+        """Type text into the element matching the selector."""
+        self._page.type(selector, text)
+
+    def browser_click(self, selector: str):
+        """Click the element matching the selector."""
+        self._page.click(selector, timeout=3000, strict=True)
+
+    def browser_drag(self, from_selector: str, to_selector: str):
+        """Drag and drop from one selector to another."""
+        from_elem = self._page.locator(from_selector)
+        from_elem.hover(timeout=500)
+        self._page.mouse.down()
+
+        to_elem = self._page.locator(to_selector)
+        to_elem.hover(timeout=500)
+        self._page.mouse.up()
+
+    def browser_hover(self, selector: str):
+        """Hover over a given element."""
+        self._page.hover(selector, timeout=3000, strict=True)
+
+    def browser_select_option(self, selector: str, value: str):
+        """Select an option from a given element."""
+        self._page.select_option(selector, value)
+
+    def browser_mouse_click_xy(self, x: int, y: int):
+        """Click at a given x, y coordinate using the mouse."""
+        self._page.mouse.click(x, y, delay=100)
+
+    def browser_wait(self, seconds: int):
+        """Wait for a given number of seconds, up to 10 seconds."""
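+        # cap the wait so a single action cannot stall the episode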
+        time.sleep(min(seconds, 10))
+
+    def evaluate_js(self, js: str):
+        js_result = self._page.evaluate(js)
+        logger.info(f"JS result: {js_result}")
+        return js_result
+
+    def goto(self, url: str):
+        """Navigate to a specified URL."""
+        self._page.goto(url)
+
+    def browser_back(self):
+        """Navigate back in browser history."""
+        self._page.go_back()
+
+    def browser_forward(self):
+        """Navigate forward in browser history."""
+        self._page.go_forward()
+
+    def page_html(self) -> str:
+        return self._page.content()
+
+    def page_screenshot(self) -> Image.Image:
+        scr_bytes = self._page.screenshot()
+        return Image.open(BytesIO(scr_bytes))
+
+    def page_axtree(self) -> str:
+        axtree = self._page.accessibility.snapshot()
+        return flatten_axtree(axtree)
+
+    def step(self, action: ToolCall) -> dict:
+        fn = self._actions[action.name]
+        try:
+            action_result = fn(**action.arguments)
+        except Exception as e:
+            action_result = f"Error executing action {action.name}: {e}"
+            logger.error(action_result)
+        html = self.page_html()
+        screenshot = self.page_screenshot()
+        axtree = self.page_axtree()
+        return {
+            "action_result": action_result,
+            "html": html,
+            "axtree_txt": axtree,
+            "screenshot": screenshot,
+        }
+
+    def actions(self) -> list[ToolSpec]:
+        return [ToolSpec.from_function(fn) for fn in self._actions.values()]
+
+    def close(self):
+        self._page.close()
+
+
+_apw = None  # Global Playwright instance for AsyncPlaywright
+_abrowser = None  # Global Browser instance for AsyncPlaywright
+
+
+class AsyncPlaywright(AsyncBrowserBackend):
+    """Fully asynchronous Playwright backend using playwright.async_api."""
+
+    has_pw_page: bool = False
+    _actions: dict[str, Callable]
+    _page: AsyncPage
+
+    def model_post_init(self, __context: Any):
+        self._actions = {
+            "browser_press_key": self.browser_press_key,
+            "browser_type": self.browser_type,
+            "browser_click": self.browser_click,
+            "browser_drag": self.browser_drag,
+            "browser_hover": self.browser_hover,
+            "browser_select_option": self.browser_select_option,
+            "browser_mouse_click_xy": self.browser_mouse_click_xy,
+        }
+
+    async def initialize(self):
+        global _apw, _abrowser
+        if _apw is None:
+            _apw = await async_playwright().start()
+        if _abrowser is None:
+            _abrowser = await _apw.chromium.launch(headless=False, chromium_sandbox=True)
+        self._page = await _abrowser.new_page()
+
+    async def browser_press_key(self, key: str):
+        """Press a key on the keyboard."""
+        await self._page.keyboard.press(key)
+
+    async def browser_type(self, selector: str, text: str):
+        """Type text into the element matching the selector."""
+        await self._page.type(selector, text)
+
+    async def browser_click(self, selector: str):
+        """Click the element matching the selector."""
+        await self._page.click(selector, timeout=3000, strict=True)
+
+    async def browser_drag(self, from_selector: str, to_selector: str):
+        """Drag and drop from one selector to another."""
+        from_elem = self._page.locator(from_selector)
+        await from_elem.hover(timeout=500)
+        await self._page.mouse.down()
+
+        to_elem = self._page.locator(to_selector)
+        await to_elem.hover(timeout=500)
+        await self._page.mouse.up()
+
+    async def browser_hover(self, selector: str):
+        """Hover over a given element."""
+        await self._page.hover(selector, timeout=3000, strict=True)
+
+    async def browser_select_option(self, selector: str, value: str):
+        """Select an option from a given element."""
+        await self._page.select_option(selector, value)
+
+    async def browser_mouse_click_xy(self, x: int, y: int):
+        """Click at a given x, y coordinate
using the mouse.""" + await self._page.mouse.click(x, y, delay=100) + + async def evaluate_js(self, js: str): + js_result = await self._page.evaluate(js) + logger.info(f"JS result: {js_result}") + return js_result + + async def goto(self, url: str): + await self._page.goto(url) + + async def page_html(self) -> str: + return await self._page.content() + + async def page_screenshot(self) -> Image.Image: + scr_bytes = await self._page.screenshot() + return Image.open(BytesIO(scr_bytes)) + + async def page_axtree(self) -> str: + axtree = await self._page.accessibility.snapshot() + return flatten_axtree(axtree) + + async def step(self, action: ToolCall) -> dict: + fn = self._actions[action.name] + try: + action_result = await fn(**action.arguments) + except Exception as e: + action_result = f"Error executing action {action.name}: {e}" + logger.error(action_result) + html = await self.page_html() + screenshot = await self.page_screenshot() + axtree = await self.page_axtree() + return { + "action_result": action_result, + "html": html, + "axtree_txt": axtree, + "screenshot": screenshot, + } + + def actions(self) -> list[ToolSpec]: + return [ToolSpec.from_function(fn) for fn in self._actions.values()] + + async def close(self): + await self._page.close() + + +def flatten_axtree(axtree_dict: dict | None) -> str: + """ + Traverses accessibility tree dictionary and returns its markdown view. + + Args: + axtree_dict: Accessibility tree from playwright page.accessibility.snapshot() + Structure: dict with 'role', 'name', 'value', 'children' keys + + Returns: + String representation of the accessibility tree in markdown format + """ + if axtree_dict is None: + return "" + + def traverse_node(node: dict, depth: int = 0) -> list[str]: + """Recursively traverse the accessibility tree and build markdown lines.""" + lines = [] + indent = " " * depth # 2 spaces per indent level + + # Extract node information + role = node.get("role", "") + name = node.get("name", "") + value = node.get("value", "") + + # Build the node representation + parts = [] + if role: + parts.append(f"{role}:") + if name.strip(): + parts.append(f"{name}") + if value: + parts.append(f"[value: {value}]") + + # Only add line if there's meaningful content + if parts: + line = f"{indent}{' '.join(parts)}" + lines.append(line) + + # Recursively process children + children = node.get("children", []) + for child in children: + child_lines = traverse_node(child, depth + 1) + lines.extend(child_lines) + + return lines + + # Start traversal from root + all_lines = traverse_node(axtree_dict, depth=0) + return "\n".join(all_lines) diff --git a/src/agentlab/benchmarks/miniwob/__init__.py b/src/agentlab/benchmarks/miniwob/__init__.py new file mode 100644 index 00000000..7b2add6f --- /dev/null +++ b/src/agentlab/benchmarks/miniwob/__init__.py @@ -0,0 +1,4 @@ +from .benchmark import MiniWobBenchmark +from .task import MiniWobTask + +__all__ = ["MiniWobBenchmark", "MiniWobTask"] diff --git a/src/agentlab/benchmarks/miniwob/benchmark.py b/src/agentlab/benchmarks/miniwob/benchmark.py new file mode 100644 index 00000000..2ce01895 --- /dev/null +++ b/src/agentlab/benchmarks/miniwob/benchmark.py @@ -0,0 +1,33 @@ +import logging +from typing import Any + +from pydantic import ConfigDict + +from agentlab.actions import ToolsActionSet +from agentlab.backends.browser.base import BrowserBackend +from agentlab.backends.browser.env import BrowserEnvArgs +from agentlab.benchmarks.abstract_env import AbstractBenchmark +from agentlab.benchmarks.miniwob.task import 
MiniWobTask, get_miniwob_tasks
+
+logger = logging.getLogger(__name__)
+
+
+class MiniWobBenchmark(AbstractBenchmark):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    backend_cls: type[BrowserBackend]
+    name: str = "miniwob"
+    env_args_list: list[BrowserEnvArgs] = None  # type: ignore
+    dataset: list[MiniWobTask] = None  # type: ignore
+    is_multi_tab: bool = False
+    high_level_action_set_args: ToolsActionSet = None  # type: ignore
+
+    def model_post_init(self, __context: Any) -> None:
+        self.name = f"miniwob_{self.backend_cls.__name__.lower()}"
+        self.env_args_list = []
+        if self.dataset is None:
+            self.dataset = get_miniwob_tasks()
+        for task in self.dataset:
+            env_args = BrowserEnvArgs(task=task, backend_cls=self.backend_cls)
+            self.env_args_list.append(env_args)
+        logger.info(f"Loaded {len(self.env_args_list)} miniwob tasks")
diff --git a/src/agentlab/benchmarks/miniwob/task.py b/src/agentlab/benchmarks/miniwob/task.py
new file mode 100644
index 00000000..019711d4
--- /dev/null
+++ b/src/agentlab/benchmarks/miniwob/task.py
@@ -0,0 +1,222 @@
+import logging
+import os
+from typing import Any, ClassVar
+
+from browsergym.miniwob import ALL_MINIWOB_TASKS
+from browsergym.utils.obs import prune_html
+
+from agentlab.backends.browser import BrowserBackend
+from agentlab.benchmarks.web_task import AbstractWebTask
+
+logger = logging.getLogger(__name__)
+
+
+class MiniWobTask(AbstractWebTask):
+    dataset: str = "miniwob"
+    desc: str
+    subdomain: str
+    base_url: str = None  # type: ignore
+    url: str = None  # type: ignore
+    remove_human_display: bool = True
+    episode_max_time: int = 1000000
+    max_turns: int = 10
+    validate_per_step: bool = True
+    actions_whitelist: ClassVar[list[str]] = [
+        "browser_press_key",
+        "browser_type",
+        "browser_click",
+        "browser_drag",
+        "browser_hover",
+        "browser_select_option",
+        "browser_mouse_click_xy",
+    ]
+
+    def model_post_init(self, __context: Any):
+        if self.base_url.endswith("/"):
+            self.base_url = self.base_url[:-1]
+        self.url = f"{self.base_url}/{self.subdomain}.html"
+
+    def setup(self, backend: BrowserBackend) -> tuple[str, dict]:
+        """
+        Set up everything needed to execute the task.
+
+        Args:
+            backend: the browser backend used to run the task.
+
+        Returns:
+            goal: str, goal of the task.
+            info: dict, custom information from the task.
+        """
+        backend.goto(self.url)
+        setup_js = self._get_setup_js()
+        setup_result = backend.evaluate_js(setup_js)
+        goal, info = self._parse_setup_result(setup_result)
+        self._backend = backend
+        return goal, info
+
+    def teardown(self) -> None:
+        """
+        Tear down the task, clean up resources if needed.
+        """
+        teardown_js = self._get_teardown_js()
+        if teardown_js:
+            self._backend.evaluate_js(teardown_js)
+
+    def validate(self) -> tuple[float, dict]:
+        """
+        Validate the task, either per step or at the end.
+
+        Returns:
+            reward: float, the reward obtained.
+            info: dict, custom information from the validation.
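+
+        Example (illustrative values): a successful per-step validation returns
+        (1.0, {"raw_reward": 1.0, "reward_reason": "", "done": True}).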
+ """ + validate_js = ( + self._get_step_validate_js() + if self.validate_per_step + else self._get_task_validate_js() + ) + validate_result = self._backend.evaluate_js(validate_js) + reward, info = self._parse_validation_result(validate_result) + return reward, info + + def _get_setup_js(self) -> str: + if self.remove_human_display: + logger.info("Remove human display") + js = r""" +let __display_ids = ['reward-display', 'click-canvas', 'sync-task-cover']; +let __display_divs = {}; +let __query_div_hidden_copy = null; + +removeDisplay = function() { + core.clearTimer(); + document.body.removeEventListener('click', core.canvasDrawClick); + + __query_div_hidden_copy = document.getElementById('query').cloneNode(true); + document.getElementById('query').innerHTML = ''; + + for (i in __display_ids) { + elem_id = __display_ids[i]; + elem = document.getElementById(elem_id); + // remove elem from the document + elem.remove(); + // but keep it stored somewhere to bring back later + __display_divs[elem_id] = elem; + } +}; + +bringBackDisplay = function() { + document.getElementById('query').innerHTML = __query_div_hidden_copy.innerHTML; + for (var elem_id in __display_divs){ + document.body.appendChild(__display_divs[elem_id]); + } + core.createDisplay(); +}; + +core.endEpisode_legacy = core.endEpisode; +core.startEpisodeReal_legacy = core.startEpisodeReal; +core.getUtterance_legacy = core.getUtterance; + +core.getUtterance = function () { + bringBackDisplay(); + utterance = core.getUtterance_legacy(); + removeDisplay(); + return utterance; +}; + +core.endEpisode = function(reward, time_proportional, reason){ + bringBackDisplay(); + core.endEpisode_legacy(reward, time_proportional, reason); + removeDisplay(); +}; + +core.startEpisodeReal = function() { + bringBackDisplay(); + core.startEpisodeReal_legacy(); + removeDisplay(); +}; + +removeDisplay(); +""" + else: + js = "" + js += f""" +Math.seedrandom(42); +core.EPISODE_MAX_TIME = {self.episode_max_time}; +core.startEpisodeReal(); +while (!WOB_TASK_READY) {{ + await new Promise(resolve => setTimeout(resolve, 100)); +}} +return core.getUtterance(); + """ + return f"async () => {{{js}}}" + + def _parse_setup_result(self, setup_result: str | dict | list) -> tuple[str, dict]: + if isinstance(setup_result, dict): + return setup_result["utterance"], {} + elif isinstance(setup_result, str): + return setup_result, {} + else: + raise ValueError(f"Unexpected setup_result type: {type(setup_result)}") + + def _get_teardown_js(self) -> str: + return "" + + def _get_step_validate_js(self) -> str: + return """() => { +return [WOB_REWARD_GLOBAL, WOB_RAW_REWARD_GLOBAL, WOB_REWARD_REASON, WOB_DONE_GLOBAL, WOB_EPISODE_ID, WOB_TASK_READY]; +}""" + + def _get_task_validate_js(self) -> str: + return """() => { +return [WOB_REWARD_GLOBAL, WOB_RAW_REWARD_GLOBAL, WOB_REWARD_REASON, WOB_DONE_GLOBAL, WOB_EPISODE_ID, WOB_TASK_READY]; +}""" + + def _parse_validation_result(self, validation_result: str | dict | list) -> tuple[float, dict]: + if isinstance(validation_result, list): + chunks = validation_result + done = chunks[3] + elif isinstance(validation_result, dict): + raise ValueError("Validation result as dict is not supported") + else: + chunks = [c.strip() for c in validation_result.split(",")] + done = chunks[3].strip().lower() == "true" + raw_reward = float(chunks[1]) + reward = float(raw_reward > 0) + return reward, { + "raw_reward": raw_reward, + "reward_reason": chunks[2], + "done": done, + } + + def obs_postprocess(self, obs: dict) -> dict: + html = 
obs.pop("html", "")
+        obs["pruned_html"] = prune_html(html)
+        if screenshot := obs.get("screenshot", None):
+            obs["screenshot"] = screenshot.crop(
+                (0, 0, 332, 214)
+            )  # crop to 332x214 because this is the viewport size for MiniWob
+        return obs
+
+
+def get_miniwob_tasks(
+    base_url: str | None = None, remove_human_display: bool = True, episode_max_time: int = 1000000
+) -> list[MiniWobTask]:
+    if base_url is None:
+        base_url = os.environ.get("MINIWOB_URL")
+    if base_url is None:
+        raise ValueError("MINIWOB_URL environment variable is not set")
+    return [
+        MiniWobTask(
+            task_id=task.subdomain,
+            desc=task.desc,
+            subdomain=task.subdomain,
+            base_url=base_url,
+            remove_human_display=remove_human_display,
+            episode_max_time=episode_max_time,
+        )
+        for task in ALL_MINIWOB_TASKS
+    ]
diff --git a/src/agentlab/benchmarks/web_task.py b/src/agentlab/benchmarks/web_task.py
new file mode 100644
index 00000000..f753828f
--- /dev/null
+++ b/src/agentlab/benchmarks/web_task.py
@@ -0,0 +1,65 @@
+from abc import ABC, abstractmethod
+from typing import ClassVar
+
+from pydantic import BaseModel
+
+from agentlab.actions import ToolSpec
+from agentlab.backends.browser import BrowserBackend
+
+
+class AbstractWebTask(BaseModel, ABC):
+    dataset: str
+    task_id: str
+    url: str
+    validate_per_step: bool = False
+    actions_whitelist: ClassVar[list[str]] = []
+    max_turns: int = 100
+    _backend: BrowserBackend = None  # type: ignore
+
+    def get_task_id(self) -> str:
+        return self.task_id
+
+    @abstractmethod
+    def setup(self, backend: BrowserBackend) -> tuple[str, dict]:
+        """
+        Set up everything needed to execute the task.
+
+        Args:
+            backend: the browser backend used to run the task.
+
+        Returns:
+            goal: str, goal of the task.
+            info: dict, custom information from the task.
+        """
+
+    @abstractmethod
+    def teardown(self):
+        """
+        Tear down the task, clean up resources if needed.
+        """
+
+    @abstractmethod
+    def validate(self) -> tuple[float, dict]:
+        """
+        Validate the task, either per step or at the end.
+
+        Returns:
+            reward: float, the reward obtained.
+            info: dict, custom information from the validation.
+        """
+
+    def cheat(self):
+        """
+        Solve the task using a pre-defined solution (optional).
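+
+        The base implementation is a no-op; override it in subclasses that can
+        provide an oracle solution.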
+ """ + + @classmethod + def filter_actions(cls, actions: list[ToolSpec]) -> list[ToolSpec]: + filtered_actions = [action for action in actions if action.function.name in cls.actions_whitelist] + return filtered_actions + + def obs_postprocess(self, obs: dict) -> dict: + return obs \ No newline at end of file diff --git a/src/agentlab/benchmarks/workarena/__init__.py b/src/agentlab/benchmarks/workarena/__init__.py new file mode 100644 index 00000000..4b038f1c --- /dev/null +++ b/src/agentlab/benchmarks/workarena/__init__.py @@ -0,0 +1,4 @@ +from .benchmark import WorkArenaBenchmark +from .task import WorkarenaTask + +__all__ = ["WorkArenaBenchmark", "WorkarenaTask"] \ No newline at end of file diff --git a/src/agentlab/benchmarks/workarena/benchmark.py b/src/agentlab/benchmarks/workarena/benchmark.py new file mode 100644 index 00000000..725a55ea --- /dev/null +++ b/src/agentlab/benchmarks/workarena/benchmark.py @@ -0,0 +1,56 @@ +import logging +from typing import Any + +from browsergym.workarena import get_all_tasks_agents +from browsergym.workarena.instance import SNowInstance +from pydantic import ConfigDict + +from agentlab.actions import ToolsActionSet +from agentlab.backends.browser.base import BrowserBackend +from agentlab.backends.browser.env import BrowserEnvArgs +from agentlab.benchmarks.abstract_env import AbstractBenchmark + +from .task import WorkarenaTask + +logger = logging.getLogger(__name__) + + +class WorkArenaBenchmark(AbstractBenchmark): + model_config = ConfigDict(arbitrary_types_allowed=True) + + backend_cls: type[BrowserBackend] + name: str = "workarena" + level: str = "l1" + n_seeds: int = 1 + env_args_list: list[BrowserEnvArgs] = None # type: ignore + dataset: list[WorkarenaTask] = None # type: ignore + is_multi_tab: bool = False + high_level_action_set_args: ToolsActionSet = None # type: ignore + _snow_instance: SNowInstance = None # type: ignore + + def model_post_init(self, __context: Any) -> None: + self.name = f"workarena_{self.level}_{self.backend_cls.__name__.lower()}" + self._snow_instance = SNowInstance() + self.env_args_list = [] + if self.dataset is None: + self.dataset = self.load_tasks(self.level) + for task in self.dataset: + env_args = BrowserEnvArgs(task=task, backend_cls=self.backend_cls) + self.env_args_list.append(env_args) + logger.info(f"Loaded {len(self.env_args_list)} workarena tasks") + + def load_tasks(self, level: str) -> list[WorkarenaTask]: + task_seed_tuples = get_all_tasks_agents(filter=self.level, n_seed_l1=self.n_seeds) + tasks = [] + for task_cls, seed in task_seed_tuples: + task = WorkarenaTask( + url="", + task_id=task_cls.get_task_id(), + instance=self._snow_instance, + task_cls=task_cls, + level=level, + seed=seed, + ) + tasks.append(task) + logger.info(f"Loaded {len(tasks)} tasks for level {level}") + return tasks \ No newline at end of file diff --git a/src/agentlab/benchmarks/workarena/task.py b/src/agentlab/benchmarks/workarena/task.py new file mode 100644 index 00000000..d2d1efda --- /dev/null +++ b/src/agentlab/benchmarks/workarena/task.py @@ -0,0 +1,57 @@ +import logging +from typing import ClassVar + +from browsergym.utils.obs import prune_html +from browsergym.workarena.instance import SNowInstance +from browsergym.workarena.tasks.base import AbstractServiceNowTask +from pydantic import ConfigDict + +from agentlab.backends.browser import BrowserBackend +from agentlab.benchmarks.web_task import AbstractWebTask + +logger = logging.getLogger(__name__) + + +class WorkarenaTask(AbstractWebTask): + model_config = 
ConfigDict(arbitrary_types_allowed=True) + + dataset: str = "workarena" + level: str + task_cls: type[AbstractServiceNowTask] + seed: int + instance: SNowInstance + _task_obj: AbstractServiceNowTask = None # type: ignore + actions_whitelist: ClassVar[list[str]] = [ + "browser_press_key", + "browser_type", + "browser_select_option", + "browser_mouse_click_xy", + "browser_wait", + "browser_back", + "browser_forward", + ] + + def setup(self, backend: BrowserBackend) -> tuple[str, dict]: + if not backend.has_pw_page: + raise ValueError("Workarena task requires a backend with playwright page access.") + self._backend = backend + self._task_obj = self.task_cls(instance=self.instance, seed=self.seed) # type: ignore + self.url = self._task_obj.start_url + goal, info = self._task_obj.setup(backend.page) + backend.goto(self.url) + logger.info(f"Current backend page URL: {backend.page.url}") + + return goal, info + + def teardown(self) -> None: + self._task_obj.teardown() + + def validate(self) -> tuple[float, dict]: + reward, done, _, info = self._task_obj.validate(page=self._backend.page, chat_messages=[]) + info["done"] = done + return reward, info + + def obs_postprocess(self, obs: dict) -> dict: + html = obs.pop("html", "") + obs["pruned_html"] = prune_html(html) + return obs diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py index de4b976a..4a8597dc 100644 --- a/src/agentlab/experiments/loop.py +++ b/src/agentlab/experiments/loop.py @@ -23,8 +23,11 @@ from browsergym.experiments.utils import count_tokens from dataclasses_json import DataClassJsonMixin from PIL import Image +from pydantic import BaseModel from tqdm import tqdm +from agentlab.backends.browser.env import BrowserEnvArgs + try: from agentlab.agents.tapeagent import TapeAgent, save_tape except ImportError: @@ -195,51 +198,24 @@ class StepInfo: profiling: StepTimestamps = field(default_factory=StepTimestamps) task_info: dict = None - def from_step(self, env: gym.Env, action: str, obs_preprocessor: callable): - t = self.profiling - t.env_start = time.time() - self.obs, self.reward, self.terminated, self.truncated, env_info = env.step(action) - t.env_stop = time.time() - + def add_action_result(self, action_result: tuple[dict, float, bool, bool, dict]): + self.obs, self.reward, self.terminated, self.truncated, env_info = action_result self.task_info = env_info.get("task_info", None) - self.raw_reward = env_info.get("RAW_REWARD_GLOBAL", None) - t.action_exec_start = env_info["action_exec_start"] # start - t.action_exect_after_timeout = env_info["action_exec_stop"] - t.action_exec_stop = env_info["action_exec_stop"] - env_info["action_exec_timeout"] - t.wait_for_page_loading_start = env_info.get("wait_for_page_loading_start", None) - t.wait_for_page_loading_stop = env_info.get("wait_for_page_loading_stop", None) - t.validation_start = env_info.get("validation_start", None) - t.validation_stop = env_info.get("validation_stop", None) - t.get_observation_start = env_info.get("get_observation_start", None) - t.get_observation_stop = env_info.get("get_observation_stop", None) - - if obs_preprocessor: - self.obs = obs_preprocessor(self.obs) - - def from_action(self, agent: Agent): - self.profiling.agent_start = time.time() - self.action, self.agent_info = agent.get_action(self.obs.copy()) - self.profiling.agent_stop = time.time() - - self.make_stats() - - return self.action - - def from_reset(self, env: gym.Env, seed: int, obs_preprocessor: callable): - t = self.profiling - t.env_start = time.time() - self.obs, 
diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py
index de4b976a..4a8597dc 100644
--- a/src/agentlab/experiments/loop.py
+++ b/src/agentlab/experiments/loop.py
@@ -23,8 +23,11 @@
 from browsergym.experiments.utils import count_tokens
 from dataclasses_json import DataClassJsonMixin
 from PIL import Image
+from pydantic import BaseModel
 from tqdm import tqdm
 
+from agentlab.backends.browser.env import BrowserEnvArgs
+
 try:
     from agentlab.agents.tapeagent import TapeAgent, save_tape
 except ImportError:
@@ -195,51 +198,24 @@ class StepInfo:
     profiling: StepTimestamps = field(default_factory=StepTimestamps)
     task_info: dict = None
 
-    def from_step(self, env: gym.Env, action: str, obs_preprocessor: callable):
-        t = self.profiling
-        t.env_start = time.time()
-        self.obs, self.reward, self.terminated, self.truncated, env_info = env.step(action)
-        t.env_stop = time.time()
-
+    def add_action_result(self, action_result: tuple[dict, float, bool, bool, dict]):
+        self.obs, self.reward, self.terminated, self.truncated, env_info = action_result
         self.task_info = env_info.get("task_info", None)
-        self.raw_reward = env_info.get("RAW_REWARD_GLOBAL", None)
-        t.action_exec_start = env_info["action_exec_start"]  # start
-        t.action_exect_after_timeout = env_info["action_exec_stop"]
-        t.action_exec_stop = env_info["action_exec_stop"] - env_info["action_exec_timeout"]
-        t.wait_for_page_loading_start = env_info.get("wait_for_page_loading_start", None)
-        t.wait_for_page_loading_stop = env_info.get("wait_for_page_loading_stop", None)
-        t.validation_start = env_info.get("validation_start", None)
-        t.validation_stop = env_info.get("validation_stop", None)
-        t.get_observation_start = env_info.get("get_observation_start", None)
-        t.get_observation_stop = env_info.get("get_observation_stop", None)
-
-        if obs_preprocessor:
-            self.obs = obs_preprocessor(self.obs)
-
-    def from_action(self, agent: Agent):
-        self.profiling.agent_start = time.time()
-        self.action, self.agent_info = agent.get_action(self.obs.copy())
-        self.profiling.agent_stop = time.time()
-
-        self.make_stats()
-
-        return self.action
-
-    def from_reset(self, env: gym.Env, seed: int, obs_preprocessor: callable):
-        t = self.profiling
-        t.env_start = time.time()
-        self.obs, env_info = env.reset(seed=seed)
-        self.reward, self.terminated, self.truncated = 0, False, False
-        t.env_stop = time.time()
-
-        t.action_exec_start = env_info.get("recording_start_time", t.env_start)
-        t.action_exect_after_timeout = t.env_stop
-        t.action_exec_stop = t.env_stop
-
-        if obs_preprocessor:
-            self.obs = obs_preprocessor(self.obs)
+        self.profiling.action_exec_start = env_info.get("action_exec_start", None)
+        self.profiling.action_exect_after_timeout = env_info["action_exec_stop"]
+        self.profiling.action_exec_stop = (
+            env_info["action_exec_stop"] - env_info["action_exec_timeout"]
+        )
+        self.profiling.wait_for_page_loading_start = env_info.get(
+            "wait_for_page_loading_start", None
+        )
+        self.profiling.wait_for_page_loading_stop = env_info.get("wait_for_page_loading_stop", None)
+        self.profiling.validation_start = env_info.get("validation_start", None)
+        self.profiling.validation_stop = env_info.get("validation_stop", None)
+        self.profiling.get_observation_start = env_info.get("get_observation_start", None)
+        self.profiling.get_observation_stop = env_info.get("get_observation_stop", None)
 
     @property
     def is_done(self):
@@ -262,7 +238,7 @@ def make_stats(self):
 
         self.stats = stats
 
-    def save_step_info(self, exp_dir, save_json=False, save_screenshot=True, save_som=False):
+    def save(self, exp_dir, save_screenshot=True, save_som=False, save_json=False):
         # special treatment for some of the observation fields
         if isinstance(self.obs, dict):
             # save screenshots to separate files
@@ -270,11 +246,17 @@ def save_step_info(self, exp_dir, save_json=False, save_screenshot=True, save_so
             screenshot_som = self.obs.pop("screenshot_som", None)
 
             if save_screenshot and screenshot is not None:
-                img = Image.fromarray(screenshot)
+                if isinstance(screenshot, Image.Image):
+                    img = screenshot
+                else:
+                    img = Image.fromarray(screenshot)
                 img.save(exp_dir / f"screenshot_step_{self.step}.png")
 
             if save_som and screenshot_som is not None:
-                img = Image.fromarray(screenshot_som)
+                if isinstance(screenshot_som, Image.Image):
+                    img = screenshot_som
+                else:
+                    img = Image.fromarray(screenshot_som)
                 img.save(exp_dir / f"screenshot_som_step_{self.step}.png")
 
         # save goal object (which might contain images) to a separate file to save space
@@ -289,14 +271,15 @@ def save_step_info(self, exp_dir, save_json=False, save_screenshot=True, save_so
 
         with gzip.open(exp_dir / f"step_{self.step}.pkl.gz", "wb") as f:
             pickle.dump(self, f)
+        logger.debug("Step info saved.")
 
         if save_json:
             with open(exp_dir / "steps_info.json", "w") as f:
                 json.dump(self, f, indent=4, cls=DataclassJSONEncoder)
+            logger.debug("Step info saved to JSON.")
 
         if isinstance(self.obs, dict):
             # add the screenshots back to the obs
-            # why do we need this?
             if screenshot is not None:
                 self.obs["screenshot"] = screenshot
             if screenshot_som is not None:
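# Caller-side sketch (editor's illustration, not part of the patch): with the
# refactor above, env.step() timing lives in the caller and add_action_result()
# only unpacks the 5-tuple; it indexes env_info["action_exec_stop"] and
# env_info["action_exec_timeout"], so both keys must be present. _FakeEnv is a
# stand-in that only mimics the expected return shape.
import time

from agentlab.experiments.loop import StepInfo


class _FakeEnv:
    def step(self, action):
        info = {"action_exec_start": time.time(), "action_exec_stop": time.time(),
                "action_exec_timeout": 0.0}
        return {"html": "<html></html>"}, 0.0, False, False, info


env = _FakeEnv()
step_info = StepInfo(step=1)
step_info.profiling.env_start = time.time()  # timing is now the caller's job
action_result = env.step("noop")
step_info.profiling.env_stop = time.time()
step_info.add_action_result(action_result)   # unpacks obs/reward/flags/profiling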
@@ -414,58 +397,61 @@ def run(self):
         env, step_info, err_msg, stack_trace = None, None, None, None
         try:
             logger.info(f"Running experiment {self.exp_name} in:\n {self.exp_dir}")
-            agent = self.agent_args.make_agent()
-            if hasattr(agent, "set_task_name"):
-                agent.set_task_name(self.env_args.task_name)
+            env, agent = self.create_env_and_agent()
 
-            logger.debug("Agent created.")
-
-            env = self.env_args.make_env(
-                action_mapping=agent.action_set.to_python_code,
-                exp_dir=self.exp_dir,
-                use_raw_page_output=getattr(self.agent_args, "use_raw_page_output", False),
-            )
-
-            logger.debug("Environment created.")
             step_info = StepInfo(step=0)
-            episode_info = [step_info]
-            step_info.from_reset(
-                env, seed=self.env_args.task_seed or 0, obs_preprocessor=agent.obs_preprocessor
-            )
+            episode_info = []
+            step_info.profiling.env_start = time.time()
+            step_info.obs, env_info = env.reset(seed=self.env_args.task_seed or 0)
+            step_info.profiling.env_stop = time.time()
+            step_info.task_info = env_info.get("task_info", None)
+            if agent.obs_preprocessor:
+                step_info.obs = agent.obs_preprocessor(step_info.obs)
             logger.debug("Environment reset.")
 
             while not step_info.is_done:  # set a limit
                 logger.debug(f"Starting step {step_info.step}.")
-                action = step_info.from_action(agent)
-                logger.debug(f"Agent chose action:\n {action}")
+                step_info.profiling.agent_start = time.time()
+                action, step_info.agent_info = agent.get_action(step_info.obs.copy())
+                step_info.action = (
+                    action.model_dump_json(indent=2)
+                    if isinstance(action, BaseModel)
+                    else str(action)
+                )
+                step_info.profiling.agent_stop = time.time()
+                if step_info.agent_info.get("think", None):
+                    logger.info(f"Agent thought: {step_info.agent_info['think']}")
+                logger.debug(f"Agent action:\n {action}")
 
                 if action is None:
                     # will end the episode after saving the step info.
                     step_info.truncated = True
 
-                step_info.save_step_info(
-                    self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
-                )
-                logger.debug("Step info saved.")
+                step_info.save(self.exp_dir, self.save_screenshot, self.save_som)
 
-                if hasattr(env.unwrapped, "chat") and isinstance(env.unwrapped.chat, Chat):
-                    _send_chat_info(env.unwrapped.chat, action, step_info.agent_info)
-                    logger.debug("Chat info sent.")
+                self.maybe_send_chat(env, action, step_info)
 
-                if action is None:
-                    logger.debug("Agent returned None action. Ending episode.")
-                    break
-
-                step_info = StepInfo(step=step_info.step + 1)
                 episode_info.append(step_info)
+                if step_info.is_done:
+                    break
+                # --- End of (obs, action, reward) step, start a new one ---
+
+                step_info = StepInfo(step=step_info.step + 1)
                 logger.debug("Sending action to environment.")
-                step_info.from_step(env, action, obs_preprocessor=agent.obs_preprocessor)
+                step_info.profiling.env_start = time.time()
+                action_result = env.step(action)
+                step_info.profiling.env_stop = time.time()
+                step_info.add_action_result(action_result)
+                if agent.obs_preprocessor:
+                    step_info.obs = agent.obs_preprocessor(step_info.obs)
                 logger.debug("Environment stepped.")
 
                 if step_info.is_done:
                     logger.debug(
                        f"Episode done: terminated: {step_info.terminated}, truncated: {step_info.truncated}."
                    )
+                    episode_info.append(step_info)
+                    break
 
         except Exception as e:
             err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}"
@@ -482,9 +468,7 @@
         finally:
             try:
                 if step_info is not None:
-                    step_info.save_step_info(
-                        self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
-                    )
+                    step_info.save(self.exp_dir, self.save_screenshot, self.save_som)
             except Exception as e:
                 logger.error(f"Error while saving step info in the finally block: {e}")
             try:
@@ -493,8 +477,7 @@
                     and len(episode_info) > 0
                     and not (episode_info[-1].terminated or episode_info[-1].truncated)
                 ):
-                    e = KeyboardInterrupt("Early termination??")
-                    err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}"
+                    err_msg = "Last step in episode was not terminated or truncated."
                 logger.info("Saving experiment info.")
                 self.save_summary_info(episode_info, Path(self.exp_dir), err_msg, stack_trace)
                 if TapeAgent is not None and isinstance(agent, TapeAgent):
@@ -512,12 +495,36 @@
         except Exception as e:
             logger.exception(f"Error while unsetting the logger: {e}")
 
+    def create_env_and_agent(self) -> tuple[gym.Env, Agent]:
+        if isinstance(self.env_args, BrowserEnvArgs):
+            env = self.env_args.make_env(exp_dir=self.exp_dir)
+            logger.debug("Environment created.")
+            agent = self.agent_args.make_agent(actions=env.actions())
+            logger.debug(f"Agent created with actions: {env.actions()}")
+        else:
+            agent = self.agent_args.make_agent()
+            if hasattr(agent, "set_task_name"):
+                agent.set_task_name(self.env_args.task_name)
+            logger.debug("Agent created.")
+            env = self.env_args.make_env(
+                action_mapping=agent.action_set.to_python_code,
+                exp_dir=self.exp_dir,
+                use_raw_page_output=getattr(self.agent_args, "use_raw_page_output", False),
+            )
+            logger.debug("Environment created.")
+        return env, agent
+
+    def maybe_send_chat(self, env: gym.Env, action: str, step_info: StepInfo):
+        if hasattr(env.unwrapped, "chat") and isinstance(env.unwrapped.chat, Chat):
+            _send_chat_info(env.unwrapped.chat, action, step_info.agent_info)
+            logger.debug("Chat info sent.")
+
     def _set_logger(self):
         # output logging traces to a log file
         file_handler = logging.FileHandler(self.exp_dir / "experiment.log")
         file_handler.setLevel(self.logging_level)  # same level as console outputs
         formatter = logging.Formatter(
-            "%(asctime)s - %(process)d - %(name)s - %(levelname)s - %(message)s"
+            "%(asctime)s - %(process)d - %(name)s:%(lineno)d - %(levelname)s - %(message)s"
         )
         file_handler.setFormatter(formatter)
         # output handler
@@ -612,6 +619,7 @@ def _aggregate_episode_stats(episode_info: list[StepInfo]):
     stats = defaultdict(list)
 
     for step_info in episode_info:
+        step_info.make_stats()
         if step_info.stats is not None:
             for key, val in step_info.stats.items():
                 if val is None:
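# Loop schematic (editor's illustration, not part of the patch): after the run()
# refactor, one StepInfo holds one (obs, action, reward) triple and the terminal
# StepInfo carries only the final observation. _FakeAgent/_FakeEnv are stand-ins
# that only mimic the interfaces run() relies on.
import time

from agentlab.experiments.loop import StepInfo


class _FakeAgent:
    def get_action(self, obs):
        return "noop", {}


class _FakeEnv:
    steps = 0

    def reset(self, seed=0):
        return {"html": ""}, {}

    def step(self, action):
        self.steps += 1
        done = self.steps >= 2
        info = {"action_exec_stop": time.time(), "action_exec_timeout": 0.0}
        return {"html": ""}, float(done), done, False, info


agent, env = _FakeAgent(), _FakeEnv()
episode_info, step_info = [], StepInfo(step=0)
step_info.obs, _ = env.reset(seed=0)
while not step_info.is_done:
    action, step_info.agent_info = agent.get_action(step_info.obs.copy())
    episode_info.append(step_info)  # a step is complete once it has its action
    step_info = StepInfo(step=step_info.step + 1)
    step_info.add_action_result(env.step(action))
episode_info.append(step_info)      # terminal step: final observation, no action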
diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py
index 391f419c..abcc3d6c 100644
--- a/src/agentlab/experiments/study.py
+++ b/src/agentlab/experiments/study.py
@@ -17,7 +17,7 @@
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.analyze import inspect_results
-from agentlab.benchmarks.abstract_env import AbstractEnvArgs
+from agentlab.benchmarks.abstract_env import AbstractBenchmark, AbstractEnvArgs
 from agentlab.experiments import reproducibility_util as repro
 from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
 from agentlab.experiments.launch_exp import (
@@ -33,7 +33,7 @@ def make_study(
     agent_args: list[AgentArgs] | AgentArgs,
-    benchmark: Benchmark | str,
+    benchmark: Benchmark | AbstractBenchmark | str,
     logging_level=logging.WARNING,
     logging_level_stdout=logging.WARNING,
     suffix="",
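# Study-level sketch (editor's illustration, not part of the patch): with the
# widened annotation above, make_study() accepts AbstractBenchmark
# implementations such as the WorkArena benchmark from this diff. Agent
# arguments mirror experiments/run_miniwob.py; the run() parameters are
# illustrative.
from agentlab.agents.react_toolcall_agent import AgentConfig, LLMArgs, ReactToolCallAgentArgs
from agentlab.backends.browser.playwright import SyncPlaywright
from agentlab.benchmarks.workarena import WorkArenaBenchmark
from agentlab.experiments.study import make_study

study = make_study(
    agent_args=ReactToolCallAgentArgs(
        llm_args=LLMArgs(model_name="azure/gpt-5-mini", temperature=1.0, max_total_tokens=128000),
        config=AgentConfig(),
    ),
    benchmark=WorkArenaBenchmark(backend_cls=SyncPlaywright),
)
study.run(n_jobs=1, n_relaunch=1, parallel_backend="sequential")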