diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py index b75a226..7327f60 100644 --- a/lagent/actions/__init__.py +++ b/lagent/actions/__init__.py @@ -2,6 +2,8 @@ from .arxiv_search import ArxivSearch, AsyncArxivSearch from .base_action import AsyncActionMixin, BaseAction, tool_api from .bing_map import AsyncBINGMap, BINGMap +from .browser_session import BrowserSession, BrowserSessionManager, BrowserTarget +from .browser_snapshot import AiSnapshotSerializer, BrowserSnapshot, SnapshotStats from .builtin_actions import FinishAction, InvalidAction, NoAction from .google_scholar_search import AsyncGoogleScholar, GoogleScholar from .google_search import AsyncGoogleSearch, GoogleSearch @@ -24,6 +26,12 @@ 'AsyncBINGMap', 'ArxivSearch', 'AsyncArxivSearch', + 'BrowserSession', + 'BrowserSessionManager', + 'BrowserSnapshot', + 'BrowserTarget', + 'AiSnapshotSerializer', + 'SnapshotStats', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar', diff --git a/lagent/actions/browser_session.py b/lagent/actions/browser_session.py new file mode 100644 index 0000000..c3f3f4b --- /dev/null +++ b/lagent/actions/browser_session.py @@ -0,0 +1,346 @@ +"""Browser session manager for Lagent browser tools. + +Manages Playwright browser sessions, tabs, element ref registries, and +artifact directories (screenshots, downloads, traces) in a thread-safe way. 
@dataclass
class BrowserTarget:
    """A single browser tab (Playwright page) tracked inside a session.

    Attributes:
        target_id: identifier of the tab within its session.
        page: the Playwright ``Page`` object backing the tab.
        url: URL recorded by the most recent :meth:`refresh_info` call.
        title: title recorded by the most recent :meth:`refresh_info` call.
    """

    target_id: str
    page: Any  # playwright Page object
    url: str = ''
    title: str = ''

    def refresh_info(self) -> None:
        """Update :attr:`url` and :attr:`title` from the live page.

        This is best effort: if the page raises (for example because it was
        closed), whatever was read before the failure is kept and the
        remaining fields keep their previous values.
        """
        try:
            self.url = self.page.url
            self.title = self.page.title()
        except Exception:
            # Deliberately swallowed: cached tab metadata is informational
            # and must never break the caller.
            pass


@dataclass
class BrowserSession:
    """A managed browser session.

    Attributes:
        session_id: unique identifier for this session.
        browser: Playwright Browser instance.
        context: Playwright BrowserContext instance.
        targets: mapping from target_id to BrowserTarget, in tab order.
        active_target_id: target_id of the currently active tab.
        refs: mapping from ref string (e.g. ``"r1"``) to element info dict.
        artifact_dir: directory path for storing screenshots/downloads/traces.
    """

    session_id: str
    browser: Any  # playwright Browser
    context: Any  # playwright BrowserContext
    targets: Dict[str, 'BrowserTarget'] = field(default_factory=dict)
    active_target_id: Optional[str] = None
    refs: Dict[str, dict] = field(default_factory=dict)
    artifact_dir: Optional[str] = None

    @property
    def active_page(self) -> Optional[Any]:
        """Return the Playwright Page for the active target, or ``None``.

        When :attr:`active_target_id` is unset or stale, the first
        registered target is used instead.
        """
        active = self.targets.get(self.active_target_id) if self.active_target_id else None
        if active is not None:
            return active.page
        # Fallback: first available target.
        for target in self.targets.values():
            return target.page
        return None

    def set_active_by_url(self, url: str) -> bool:
        """Switch the active target to the first tab whose URL matches.

        Every inspected tab has its cached url/title refreshed first.

        Args:
            url (str): URL (or prefix) to match. An empty value never
                matches; without that guard ``startswith('')`` would select
                the first tab regardless of its URL.

        Returns:
            bool: ``True`` if a matching target was found and activated.
        """
        if not url:
            return False
        for tid, target in self.targets.items():
            target.refresh_info()
            # A prefix match also covers exact equality.
            if target.url.startswith(url):
                self.active_target_id = tid
                return True
        return False

    def set_active_by_index(self, index: int) -> bool:
        """Switch the active target by zero-based tab index.

        Args:
            index (int): zero-based index into :attr:`targets`. Negative
                indices are rejected.

        Returns:
            bool: ``True`` if the index was valid.
        """
        keys = list(self.targets)
        if 0 <= index < len(keys):
            self.active_target_id = keys[index]
            return True
        return False

    def bind_refs(self, elements: List[dict]) -> None:
        """Register interactive elements as named refs.

        Any previously bound refs are discarded, so ref ids always describe
        the most recent snapshot.

        Args:
            elements (list[dict]): element info dicts produced by the
                snapshot serializer. They are stored as-is (not copied)
                under ``r1``, ``r2``, ... in list order.
        """
        self.refs.clear()
        self.refs.update(
            (f'r{position}', element)
            for position, element in enumerate(elements, start=1)
        )

    def resolve_ref(self, ref: str) -> Optional[dict]:
        """Return element info for a ref string such as ``"r1"``.

        Args:
            ref (str): ref identifier.

        Returns:
            dict | None: element info dict, or ``None`` if not found.
        """
        return self.refs.get(ref)


class BrowserSessionManager:
    """Process-wide singleton registry of Playwright browser sessions.

    A single non-reentrant lock serialises every change to the session
    registry.

    NOTE(review): the lock only protects this manager's own bookkeeping.
    Playwright documents its API as not thread-safe, so each session should
    presumably be driven from the thread that created it -- confirm against
    callers.

    Usage::

        manager = BrowserSessionManager()
        session = manager.get_or_create_session('my-session')
        page = session.active_page
        # ... do stuff with page ...
        manager.close_session('my-session')
    """

    _instance: Optional['BrowserSessionManager'] = None
    _class_lock: threading.Lock = threading.Lock()

    def __new__(cls) -> 'BrowserSessionManager':
        with cls._class_lock:
            if cls._instance is None:
                inst = super().__new__(cls)
                # State is initialised here rather than in __init__ so that
                # later BrowserSessionManager() calls cannot reset it.
                inst._sessions = {}
                inst._lock = threading.Lock()
                inst._playwright = None
                inst._playwright_ctx = None
                cls._instance = inst
            return cls._instance

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _ensure_playwright(self) -> None:
        """Start the shared Playwright driver on first use.

        Raises:
            RuntimeError: if ``playwright`` is not installed.
        """
        if not PLAYWRIGHT_AVAILABLE:
            raise RuntimeError(
                'playwright is not installed. '
                'Install it with: pip install playwright && playwright install'
            )
        if self._playwright is None:
            self._playwright_ctx = sync_playwright()
            self._playwright = self._playwright_ctx.start()

    def _make_artifact_dir(self, session_id: str, base: Optional[str]) -> str:
        """Create (if needed) and return ``<base>/<session_id>``.

        *base* defaults to ``<cwd>/.browser_artifacts``.
        """
        root = base or os.path.join(os.getcwd(), '.browser_artifacts')
        artifact_dir = os.path.join(root, session_id)
        os.makedirs(artifact_dir, exist_ok=True)
        return artifact_dir

    def _create_session_locked(
        self,
        session_id: Optional[str] = None,
        artifact_dir: Optional[str] = None,
        browser_type: str = 'chromium',
        headless: bool = True,
        **launch_kwargs: Any,
    ) -> BrowserSession:
        """Launch a browser and register a session; caller holds the lock.

        Shared by :meth:`create_session` and :meth:`get_or_create_session`
        so that "look up, then create" is one atomic step.
        """
        self._ensure_playwright()
        session_id = session_id or str(uuid.uuid4())
        if session_id in self._sessions:
            raise RuntimeError(f"Session '{session_id}' already exists.")

        launcher = getattr(self._playwright, browser_type)
        browser: Browser = launcher.launch(headless=headless, **launch_kwargs)
        try:
            context: BrowserContext = browser.new_context()
            page: Page = context.new_page()

            target_id = str(uuid.uuid4())
            target = BrowserTarget(target_id=target_id, page=page)
            target.refresh_info()

            art_dir = self._make_artifact_dir(session_id, artifact_dir)
        except Exception:
            # The browser process is already running; without this it would
            # be orphaned because no session was registered for it.
            try:
                browser.close()
            except Exception:
                pass
            raise

        session = BrowserSession(
            session_id=session_id,
            browser=browser,
            context=context,
            targets={target_id: target},
            active_target_id=target_id,
            artifact_dir=art_dir,
        )
        self._sessions[session_id] = session
        return session

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_session(
        self,
        session_id: Optional[str] = None,
        artifact_dir: Optional[str] = None,
        browser_type: str = 'chromium',
        headless: bool = True,
        **launch_kwargs: Any,
    ) -> BrowserSession:
        """Launch a new browser and create a session.

        Args:
            session_id (str | None): identifier for the session. A random
                UUID is used when not provided.
            artifact_dir (str | None): root directory for browser artifacts.
                Defaults to ``<cwd>/.browser_artifacts``; the session gets
                its own ``<session_id>`` sub-directory.
            browser_type (str): Playwright browser type – ``'chromium'``,
                ``'firefox'``, or ``'webkit'``. Defaults to ``'chromium'``.
            headless (bool): run the browser in headless mode. Defaults to
                ``True``.
            **launch_kwargs: extra keyword arguments forwarded to
                ``browser_type.launch()``.

        Returns:
            BrowserSession: the newly created session.

        Raises:
            RuntimeError: if ``playwright`` is not installed, or if
                *session_id* is already in use.
        """
        with self._lock:
            return self._create_session_locked(
                session_id=session_id,
                artifact_dir=artifact_dir,
                browser_type=browser_type,
                headless=headless,
                **launch_kwargs,
            )

    def get_session(self, session_id: str) -> Optional[BrowserSession]:
        """Return an existing session by ID, or ``None`` if not found.

        Args:
            session_id (str): session identifier.

        Returns:
            BrowserSession | None: the session object.
        """
        with self._lock:
            return self._sessions.get(session_id)

    def get_or_create_session(
        self,
        session_id: str,
        **kwargs: Any,
    ) -> BrowserSession:
        """Return an existing session or create a new one.

        Lookup and creation happen under a single lock acquisition, so two
        concurrent callers asking for the same unknown id cannot both try to
        create it (which would raise "already exists" for one of them).

        Args:
            session_id (str): session identifier.
            **kwargs: forwarded to :meth:`create_session` when creating.

        Returns:
            BrowserSession: existing or newly created session.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if session is not None:
                return session
            return self._create_session_locked(session_id=session_id, **kwargs)

    def list_sessions(self) -> List[str]:
        """Return a list of all active session IDs.

        Returns:
            list[str]: session identifiers.
        """
        with self._lock:
            return list(self._sessions)

    def open_tab(self, session_id: str, url: Optional[str] = None) -> str:
        """Open a new tab in an existing session and make it active.

        Args:
            session_id (str): session identifier.
            url (str | None): optional URL to navigate the new tab to.

        Returns:
            str: the new target ID.

        Raises:
            KeyError: if *session_id* does not exist.
        """
        with self._lock:
            session = self._sessions[session_id]
            page: Page = session.context.new_page()
            if url:
                try:
                    page.goto(url)
                except Exception:
                    # The tab was never registered, so nobody else could
                    # close it; release it before reporting the failure.
                    try:
                        page.close()
                    except Exception:
                        pass
                    raise
            target_id = str(uuid.uuid4())
            target = BrowserTarget(target_id=target_id, page=page)
            target.refresh_info()
            session.targets[target_id] = target
            session.active_target_id = target_id
            return target_id

    def close_tab(self, session_id: str, target_id: str) -> None:
        """Close a specific tab within a session.

        If the closed tab was active, the first remaining tab (if any)
        becomes active.

        Args:
            session_id (str): session identifier.
            target_id (str): target identifier to close.

        Raises:
            KeyError: if either *session_id* or *target_id* does not exist.
        """
        with self._lock:
            session = self._sessions[session_id]
            target = session.targets.pop(target_id)
            try:
                target.page.close()
            except Exception:
                # Best effort: the page may already be gone.
                pass
            if session.active_target_id == target_id:
                session.active_target_id = next(iter(session.targets), None)

    def close_session(self, session_id: str) -> None:
        """Close a browser session and release all resources.

        Args:
            session_id (str): session identifier. No-op if not found.
        """
        with self._lock:
            session = self._sessions.pop(session_id, None)
        if session is None:
            return
        try:
            session.browser.close()
        except Exception:
            # Best effort: the browser may already have exited.
            pass

    def close_all(self) -> None:
        """Close all sessions and stop the Playwright process."""
        with self._lock:
            session_ids = list(self._sessions)
        # close_session takes the (non-reentrant) lock itself, so it must be
        # called with the lock released.
        for sid in session_ids:
            self.close_session(sid)
        with self._lock:
            if self._playwright is not None:
                try:
                    self._playwright_ctx.stop()
                except Exception:
                    pass
                self._playwright = None
                self._playwright_ctx = None
# ---------------------------------------------------------------------------
# JavaScript snippet executed inside the browser to collect element metadata.
# It returns one plain object per visible interactive element with the keys
# tag, type, role, text, href, placeholder, name, value and options.
# ---------------------------------------------------------------------------
_COLLECT_ELEMENTS_JS = """
() => {
    const SELECTORS = [
        'a[href]',
        'button:not([disabled])',
        'input:not([disabled])',
        'select:not([disabled])',
        'textarea:not([disabled])',
        '[role="button"]:not([disabled])',
        '[role="link"]',
        '[role="checkbox"]',
        '[role="radio"]',
        '[role="menuitem"]',
        '[role="option"]',
        '[role="tab"]',
        '[role="combobox"]',
        '[contenteditable="true"]',
    ].join(', ');

    function isVisible(el) {
        if (!el) return false;
        const rect = el.getBoundingClientRect();
        if (rect.width === 0 && rect.height === 0) return false;
        const style = window.getComputedStyle(el);
        return style.display !== 'none'
            && style.visibility !== 'hidden'
            && parseFloat(style.opacity) > 0;
    }

    function getText(el) {
        const t = (el.innerText || el.textContent || '').trim();
        return t.slice(0, 120);
    }

    function getLabel(el) {
        const id = el.getAttribute('id');
        if (id) {
            // Escape the id: a quote or backslash in it would otherwise make
            // querySelector throw and abort the whole collection.
            const lbl = document.querySelector(`label[for="${CSS.escape(id)}"]`);
            if (lbl) return (lbl.innerText || '').trim();
        }
        return el.getAttribute('aria-label') || '';
    }

    const seen = new Set();
    const results = [];
    document.querySelectorAll(SELECTORS).forEach(el => {
        if (!isVisible(el) || seen.has(el)) return;
        seen.add(el);

        const tag = el.tagName.toLowerCase();
        const type = el.getAttribute('type') || '';
        const role = el.getAttribute('role') || '';
        let text = getText(el);
        if (!text) text = getLabel(el);
        if (!text) text = el.getAttribute('aria-label') || '';
        if (!text) text = el.getAttribute('placeholder') || '';
        if (!text) text = el.getAttribute('value') || '';
        if (!text && tag === 'input') text = el.getAttribute('name') || '';

        const info = {
            tag,
            type,
            role,
            text: text.slice(0, 100),
            href: el.getAttribute('href') || '',
            placeholder: el.getAttribute('placeholder') || '',
            name: el.getAttribute('name') || '',
            value: (el.tagName === 'INPUT' || el.tagName === 'TEXTAREA')
                ? (el.value || '') : '',
            options: [],
        };

        if (tag === 'select') {
            info.options = Array.from(el.options).map(o => o.text.trim());
        }

        results.push(info);
    });
    return results;
}
"""


# ---------------------------------------------------------------------------
# Snapshot statistics dataclass
# ---------------------------------------------------------------------------

@dataclass
class SnapshotStats:
    """Statistics emitted alongside a browser snapshot.

    Attributes:
        lines: number of lines in the page-text section.
        chars: total characters in the page-text section (including the
            truncation marker when the section was cut).
        refs: number of interactive refs registered.
        interactive: number of interactive elements found on the page.
    """

    lines: int = 0
    chars: int = 0
    refs: int = 0
    interactive: int = 0


# ---------------------------------------------------------------------------
# AI snapshot serializer
# ---------------------------------------------------------------------------

class AiSnapshotSerializer:
    """Convert Playwright page content into a model-friendly text snapshot.

    The output format is::

        URL: <page url>
        Title: <page title>

        === INTERACTIVE ELEMENTS ===
        [r1] LINK "Home" href="/home"
        [r2] BUTTON "Submit"
        [r3] INPUT text placeholder="Search..." name="q"
        [r4] SELECT options=["Option A", "Option B"]

        === PAGE TEXT ===
        <visible page text, truncated to max_text_chars>

        --- stats: lines=42 chars=1234 refs=4 interactive=4 ---

    Args:
        max_total_chars (int): cap applied to the whole assembled snapshot,
            stats line included. When the cap is hit the text is cut and a
            ``[... truncated ...]`` marker is appended, so the result can
            exceed the cap by the marker's length and the stats line may be
            cut off. Defaults to ``20000``.
        max_text_chars (int): maximum characters for the page-text section
            (the same marker is appended when it is cut).
            Defaults to ``10000``.
        max_refs (int): maximum number of interactive element refs to emit.
            Defaults to ``100``.
    """

    def __init__(
        self,
        max_total_chars: int = 20_000,
        max_text_chars: int = 10_000,
        max_refs: int = 100,
    ) -> None:
        self.max_total_chars = max_total_chars
        self.max_text_chars = max_text_chars
        self.max_refs = max_refs

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def serialize(
        self,
        page: Any,
        existing_refs: Optional[Dict[str, dict]] = None,
    ) -> Tuple[str, SnapshotStats, List[dict]]:
        """Produce a text snapshot of *page*.

        Failures while reading the title, the body text or the interactive
        elements are tolerated; the affected section is simply left empty.

        Args:
            page: a Playwright ``Page`` object.
            existing_refs (dict | None): not used directly; reserved for
                future incremental diffing.

        Returns:
            tuple: ``(snapshot_text, stats, elements)`` where *elements* is
            the list of interactive element dicts (capped at ``max_refs``)
            that should be bound as refs in the session.
        """
        url = page.url
        try:
            title = page.title()
        except Exception:
            title = ''

        elements = self._extract_interactive_elements(page)
        capped_elements = elements[: self.max_refs]

        # Ref ids are positional so they line up with
        # ``BrowserSession.bind_refs`` (r1 is the first element).
        interactive_lines = [
            self._format_element(el, f'r{position}')
            for position, el in enumerate(capped_elements, start=1)
        ]

        try:
            raw_text = page.inner_text('body') or ''
        except Exception:
            raw_text = ''
        page_text = self._truncate(raw_text, self.max_text_chars)

        stats = SnapshotStats(
            lines=len(page_text.splitlines()),
            chars=len(page_text),
            refs=len(capped_elements),
            interactive=len(elements),
        )

        parts: List[str] = [
            f'URL: {url}',
            f'Title: {title}',
            '',
        ]
        if interactive_lines:
            parts += ['=== INTERACTIVE ELEMENTS ==='] + interactive_lines + ['']
        if page_text:
            parts += ['=== PAGE TEXT ===', page_text, '']
        parts.append(
            f'--- stats: lines={stats.lines} chars={stats.chars} '
            f'refs={stats.refs} interactive={stats.interactive} ---'
        )

        snapshot = self._truncate('\n'.join(parts), self.max_total_chars)
        return snapshot, stats, capped_elements

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _extract_interactive_elements(self, page: Any) -> List[dict]:
        """Run :data:`_COLLECT_ELEMENTS_JS` in the page and return results.

        Any evaluation error yields an empty list so that a snapshot can
        still be produced from the page text alone.
        """
        try:
            return page.evaluate(_COLLECT_ELEMENTS_JS) or []
        except Exception:
            return []

    @staticmethod
    def _quote(value: Any) -> str:
        """Return *value* as a single-line, double-quoted string.

        Runs of whitespace (including newlines from multi-line element
        text) collapse to one space, and embedded quotes/backslashes are
        escaped, so a value can never break the one-line ref format.
        """
        collapsed = ' '.join(str(value).split())
        return json.dumps(collapsed, ensure_ascii=False)

    def _format_element(self, el: dict, ref_id: str) -> str:
        """Render a single element dict as a compact one-line ref entry.

        Args:
            el (dict): element info with the optional keys ``tag``,
                ``role``, ``type``, ``text``, ``href``, ``placeholder``,
                ``name`` and ``options``; missing or ``None`` values are
                treated as empty.
            ref_id (str): ref identifier such as ``"r1"``.

        Returns:
            str: ``[ref] KIND "text" href=... placeholder=... name=...
            options=[...]`` with empty parts omitted.
        """
        tag = str(el.get('tag') or '').upper()
        role = str(el.get('role') or '').upper()
        el_type = el.get('type') or ''
        options = el.get('options') or []

        # Determine display kind
        if tag == 'A' or role == 'LINK':
            kind = 'LINK'
        elif tag == 'BUTTON' or role == 'BUTTON':
            kind = 'BUTTON'
        elif tag == 'INPUT':
            kind = f'INPUT {el_type}' if el_type else 'INPUT'
        elif tag == 'SELECT':
            kind = 'SELECT'
        elif tag == 'TEXTAREA':
            kind = 'TEXTAREA'
        else:
            kind = role or tag or 'ELEMENT'

        parts = [f'[{ref_id}]', kind]

        text = ' '.join(str(el.get('text') or '').split())
        if text:
            parts.append(self._quote(text))
        for key in ('href', 'placeholder', 'name'):
            value = el.get(key) or ''
            if value:
                parts.append(f'{key}={self._quote(value)}')
        if options:
            # Only the first ten options are shown to keep the line compact.
            opts_str = json.dumps(options[:10], ensure_ascii=False)
            parts.append(f'options={opts_str}')

        return ' '.join(parts)

    @staticmethod
    def _truncate(text: str, max_chars: int) -> str:
        """Truncate *text* to *max_chars* characters deterministically.

        When a cut happens a ``[... truncated ...]`` marker line is
        appended after the kept prefix.
        """
        if len(text) <= max_chars:
            return text
        return text[:max_chars] + '\n[... truncated ...]'
class BrowserSnapshot(BaseAction):
    """Capture a model-friendly snapshot of the active browser page.

    This action connects to a managed
    :class:`~lagent.actions.browser_session.BrowserSession`, extracts visible
    text and interactive elements from the current page, registers them as
    named refs (``r1``, ``r2``, …), and optionally saves a screenshot.

    Args:
        artifact_dir (str | None): root directory for browser artifacts
            (screenshots, downloads, traces). Defaults to
            ``<cwd>/.browser_artifacts``.
        max_total_chars (int): hard cap on snapshot length.
            Defaults to ``20000``.
        max_text_chars (int): cap on the page-text section.
            Defaults to ``10000``.
        max_refs (int): maximum interactive refs to include.
            Defaults to ``100``.
        description (dict | None): custom tool description.
        parser (Type[BaseParser]): parser class. Defaults to
            :class:`~lagent.actions.parser.JsonParser`.

    Example::

        from lagent.actions.browser_snapshot import BrowserSnapshot
        from lagent.actions.browser_session import BrowserSessionManager

        manager = BrowserSessionManager()
        session = manager.create_session('demo')
        session.active_page.goto('https://example.com')

        snap = BrowserSnapshot()
        result = snap.run(session_id='demo')
        print(result['snapshot'])
    """

    def __init__(
        self,
        artifact_dir: Optional[str] = None,
        max_total_chars: int = 20_000,
        max_text_chars: int = 10_000,
        max_refs: int = 100,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ) -> None:
        self._artifact_dir = artifact_dir
        self._max_total_chars = max_total_chars
        self._max_text_chars = max_text_chars
        self._max_refs = max_refs
        # Without Playwright the action can still be constructed; run()
        # then reports the missing dependency instead of raising.
        if PLAYWRIGHT_AVAILABLE and BrowserSessionManager is not None:
            self._session_manager: Optional[BrowserSessionManager] = BrowserSessionManager()
        else:
            self._session_manager = None
        # Serializer configured with the instance defaults; run() builds its
        # own serializer so per-call overrides never mutate this one.
        self._serializer = AiSnapshotSerializer(
            max_total_chars=max_total_chars,
            max_text_chars=max_text_chars,
            max_refs=max_refs,
        )
        super().__init__(description, parser)

    # ------------------------------------------------------------------
    # Tool API
    # ------------------------------------------------------------------

    @tool_api(explode_return=True)
    def run(
        self,
        session_id: str,
        target: Optional[str] = None,
        include_image: bool = False,
        max_total_chars: Optional[int] = None,
        max_text_chars: Optional[int] = None,
        max_refs: Optional[int] = None,
    ) -> dict:
        """Capture a snapshot of the current browser page.

        Args:
            session_id (str): browser session identifier; a new session is
                created automatically if *session_id* is not yet known.
            target (str): select a tab by zero-based index (e.g. ``"0"``) or
                by matching URL prefix. Omit to use the currently active tab.
            include_image (bool): when ``True``, a PNG screenshot is saved to
                the session's artifact directory and its path is returned.
            max_total_chars (int): per-call override for the total snapshot
                character limit.
            max_text_chars (int): per-call override for the page-text section
                character limit.
            max_refs (int): per-call override for the maximum number of
                interactive refs.

        Returns:
            dict: snapshot result
                * snapshot: model-friendly text representation of the page
                * url: URL of the active page
                * title: title of the active page
                * stats: dict with keys ``lines``, ``chars``, ``refs``,
                  ``interactive``
                * screenshot_path: absolute path to screenshot PNG, or ``""``
                  when *include_image* is ``False``
        """
        if not PLAYWRIGHT_AVAILABLE or self._session_manager is None:
            return self._error(
                session_id,
                'playwright is not installed. '
                'Install with: pip install playwright && playwright install',
            )

        # Resolve (or create) session
        try:
            session: BrowserSession = self._session_manager.get_or_create_session(
                session_id,
                artifact_dir=self._artifact_dir,
            )
        except Exception as exc:
            return self._error(session_id, f'Failed to get/create session: {exc}')

        # Optionally switch active tab. This stays best effort: when the
        # selector matches nothing the currently active tab is snapshotted.
        if target is not None:
            self._switch_target(session, target)

        page = session.active_page
        if page is None:
            return self._error(session_id, 'No active page found in session.')

        # Per-call serializer. Overrides are compared against ``None`` so an
        # explicit 0 (e.g. ``max_refs=0`` for a text-only snapshot) is
        # honoured instead of silently falling back to the default.
        serializer = AiSnapshotSerializer(
            max_total_chars=self._pick(max_total_chars, self._max_total_chars),
            max_text_chars=self._pick(max_text_chars, self._max_text_chars),
            max_refs=self._pick(max_refs, self._max_refs),
        )

        try:
            snapshot_text, stats, elements = serializer.serialize(
                page, existing_refs=session.refs
            )
        except Exception as exc:
            return self._error(session_id, f'Serialization failed: {exc}')

        # Bind refs into session state
        session.bind_refs(elements)

        # Optionally take a screenshot
        screenshot_path = ''
        if include_image:
            screenshot_path = self._capture_screenshot(session, page)

        # Refresh active target info
        if session.active_target_id and session.active_target_id in session.targets:
            session.targets[session.active_target_id].refresh_info()

        current_url = page.url
        try:
            current_title = page.title()
        except Exception:
            current_title = ''

        return {
            'snapshot': snapshot_text,
            'url': current_url,
            'title': current_title,
            'stats': asdict(stats),
            'screenshot_path': screenshot_path,
        }

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _pick(override: Optional[int], default: int) -> int:
        """Return *override* unless it is ``None``, else *default*."""
        return default if override is None else override

    def _error(self, session_id: str, message: str) -> 'ActionReturn':
        """Build the ``API_ERROR`` result returned by :meth:`run` on failure."""
        return ActionReturn(
            args={'session_id': session_id},
            type=self.name,
            errmsg=message,
            state=ActionStatusCode.API_ERROR,
        )

    def _switch_target(self, session: 'BrowserSession', target: str) -> bool:
        """Attempt to switch the session's active tab.

        *target* is tried as a zero-based integer index first; if it is not
        an integer, or the index does not exist, it is matched against tab
        URLs.

        Args:
            session: the active :class:`BrowserSession`.
            target (str): tab selector.

        Returns:
            bool: ``True`` if a tab was selected, ``False`` if nothing
            matched (the active tab is then left unchanged).
        """
        try:
            index = int(target)
        except (ValueError, TypeError):
            index = None
        if index is not None and session.set_active_by_index(index):
            return True
        # str() keeps the URL match safe when a parser hands over a
        # non-string value such as an out-of-range integer.
        return session.set_active_by_url(str(target))

    def _capture_screenshot(self, session: 'BrowserSession', page: Any) -> str:
        """Save a PNG screenshot and return the absolute file path.

        The file is written to the session's artifact directory (or the
        current working directory when the session has none) and is named
        after the current time in milliseconds.

        Args:
            session: the active :class:`BrowserSession`.
            page: Playwright ``Page`` object.

        Returns:
            str: absolute path to the screenshot file, or ``""`` on failure.
        """
        try:
            import time
            art_dir = session.artifact_dir or os.getcwd()
            os.makedirs(art_dir, exist_ok=True)
            timestamp = int(time.time() * 1000)
            path = os.path.join(art_dir, f'screenshot_{timestamp}.png')
            # Viewport only; full_page=False avoids very tall captures.
            page.screenshot(path=path, full_page=False)
            return os.path.abspath(path)
        except Exception:
            # Best effort: a failed screenshot must not fail the snapshot.
            return ''
# ---------------------------------------------------------------------------

def _make_fake_playwright_module():
    """Return a fake ``playwright.sync_api`` module with stub classes.

    The fake package is registered in ``sys.modules`` only when no
    ``playwright`` entry exists yet, so a real installation wins.
    """
    package = types.ModuleType('playwright')
    sync_api = types.ModuleType('playwright.sync_api')

    class _FakePlaywright:
        chromium = MagicMock()

    class _FakeAPI:
        """Minimal stub for sync_playwright context."""

        def start(self):
            return _FakePlaywright()

        def stop(self):
            pass

    sync_api.sync_playwright = lambda: _FakeAPI()
    for type_name in ('Browser', 'BrowserContext', 'Page'):
        setattr(sync_api, type_name, object)

    package.sync_api = sync_api
    for module_name, module in (
        ('playwright', package),
        ('playwright.sync_api', sync_api),
    ):
        sys.modules.setdefault(module_name, module)
    return sync_api


_make_fake_playwright_module()

# Now import the modules under test
from lagent.actions.browser_session import (  # noqa: E402
    BrowserSession,
    BrowserSessionManager,
    BrowserTarget,
)
from lagent.actions.browser_snapshot import (  # noqa: E402
    AiSnapshotSerializer,
    BrowserSnapshot,
    SnapshotStats,
)
from lagent.schema import ActionReturn, ActionStatusCode  # noqa: E402


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Marker appended by ``AiSnapshotSerializer._truncate`` when text is cut.
_TRUNCATION_MARKER = '\n[... truncated ...]'


def _fake_page(
    url='https://example.com',
    title='Example',
    body_text='Hello world\nLine two',
    elements=None,
):
    """Build a MagicMock that mimics the Playwright Page API."""
    page = MagicMock()
    page.url = url
    page.configure_mock(**{
        'title.return_value': title,
        'inner_text.return_value': body_text,
        'evaluate.return_value': [] if elements is None else elements,
    })
    return page


def _element(tag, **fields):
    """Return an element-info dict shaped like the in-page collector output.

    Every key defaults to an empty value; *fields* overrides them.
    """
    info = {
        'tag': tag, 'type': '', 'role': '', 'text': '',
        'href': '', 'placeholder': '', 'name': '', 'value': '',
        'options': [],
    }
    info.update(fields)
    return info


# ---------------------------------------------------------------------------
# SnapshotStats tests
# ---------------------------------------------------------------------------

class TestSnapshotStats(unittest.TestCase):
    """Defaults and dict conversion of :class:`SnapshotStats`."""

    def test_defaults(self):
        stats = SnapshotStats()
        for field_name in ('lines', 'chars', 'refs', 'interactive'):
            self.assertEqual(getattr(stats, field_name), 0)

    def test_asdict(self):
        as_mapping = asdict(SnapshotStats(lines=3, chars=42, refs=2, interactive=5))
        self.assertEqual(
            as_mapping,
            {'lines': 3, 'chars': 42, 'refs': 2, 'interactive': 5},
        )


# ---------------------------------------------------------------------------
# AiSnapshotSerializer tests
# ---------------------------------------------------------------------------

class TestAiSnapshotSerializer(unittest.TestCase):
    """Snapshot text, stats and caps produced by the serializer."""

    def setUp(self):
        self.serializer = AiSnapshotSerializer(
            max_total_chars=5000,
            max_text_chars=2000,
            max_refs=10,
        )

    def test_empty_page(self):
        snapshot, stats, elements = self.serializer.serialize(
            _fake_page(body_text='', elements=[])
        )
        self.assertIn('URL: https://example.com', snapshot)
        self.assertIn('Title: Example', snapshot)
        self.assertEqual(stats.refs, 0)
        self.assertEqual(stats.interactive, 0)
        self.assertEqual(elements, [])

    def test_page_text_included(self):
        snapshot, stats, _ = self.serializer.serialize(
            _fake_page(body_text='Welcome to the site\nSecond line')
        )
        self.assertIn('=== PAGE TEXT ===', snapshot)
        self.assertIn('Welcome to the site', snapshot)
        self.assertGreater(stats.lines, 0)
        self.assertGreater(stats.chars, 0)

    def test_interactive_elements_included(self):
        page = _fake_page(elements=[
            _element('a', text='Home', href='/home'),
            _element('button', text='Submit'),
        ])
        snapshot, stats, out_elements = self.serializer.serialize(page)
        for expected in ('=== INTERACTIVE ELEMENTS ===', '[r1]', '[r2]',
                         'LINK', 'BUTTON'):
            self.assertIn(expected, snapshot)
        self.assertEqual(stats.refs, 2)
        self.assertEqual(stats.interactive, 2)
        self.assertEqual(len(out_elements), 2)

    def test_max_refs_cap(self):
        capped = AiSnapshotSerializer(max_refs=2)
        buttons = [_element('button', text=f'B{i}') for i in range(5)]
        _, stats, out_elements = capped.serialize(_fake_page(elements=buttons))
        self.assertEqual(stats.refs, 2)
        self.assertEqual(stats.interactive, 5)
        self.assertEqual(len(out_elements), 2)

    def test_text_truncation(self):
        limited = AiSnapshotSerializer(max_text_chars=100, max_total_chars=5000)
        snapshot, stats, _ = limited.serialize(_fake_page(body_text='x' * 5000))
        self.assertIn('[... truncated ...]', snapshot)
        # stats.chars counts the length of the (truncated) page_text section,
        # which is max_text_chars + the truncation marker string.
        self.assertLessEqual(stats.chars, 100 + len(_TRUNCATION_MARKER))

    def test_total_chars_truncation(self):
        limited = AiSnapshotSerializer(max_total_chars=500, max_text_chars=10000)
        snapshot, _, _ = limited.serialize(_fake_page(body_text='y' * 20000))
        self.assertLessEqual(len(snapshot), 500 + len(_TRUNCATION_MARKER))

    def test_stats_line_present(self):
        snapshot, _, _ = self.serializer.serialize(_fake_page())
        self.assertIn('--- stats:', snapshot)

    def test_select_element_options(self):
        dropdown = _element('select', name='color',
                            options=['Red', 'Green', 'Blue'])
        snapshot, _, _ = self.serializer.serialize(_fake_page(elements=[dropdown]))
        self.assertIn('SELECT', snapshot)
        self.assertIn('Red', snapshot)

    def test_input_element_formatting(self):
        search_box = _element('input', type='text',
                              placeholder='Search...', name='q')
        snapshot, _, _ = self.serializer.serialize(_fake_page(elements=[search_box]))
        self.assertIn('INPUT text', snapshot)
        self.assertIn('placeholder="Search..."', snapshot)


# ---------------------------------------------------------------------------
# BrowserTarget / BrowserSession tests
# ---------------------------------------------------------------------------

class TestBrowserTarget(unittest.TestCase):
    """Cached url/title handling of :class:`BrowserTarget`."""

    def test_refresh_info(self):
        target = BrowserTarget(
            target_id='t1',
            page=_fake_page(url='https://test.com', title='Test'),
        )
        self.assertEqual(target.url, '')  # not yet refreshed
        target.refresh_info()
        self.assertEqual(target.url, 'https://test.com')
        self.assertEqual(target.title, 'Test')

    def test_refresh_info_error_suppressed(self):
        broken_page = MagicMock()
        broken_page.url = 'https://ok.com'
        broken_page.title.side_effect = Exception('disconnected')
        target = BrowserTarget(target_id='t2', page=broken_page)
        # Should not raise even if title() throws
        target.refresh_info()
        self.assertEqual(target.url, 'https://ok.com')
BrowserSession( + session_id='s1', + browser=MagicMock(), + context=MagicMock(), + targets={'t1': target}, + active_target_id='t1', + ) + return session, page + + def test_active_page_returns_correct_page(self): + session, page = self._make_session() + self.assertIs(session.active_page, page) + + def test_active_page_fallback(self): + session, page = self._make_session() + session.active_target_id = None + # Should fall back to first target + self.assertIs(session.active_page, page) + + def test_active_page_none_when_empty(self): + session = BrowserSession( + session_id='empty', + browser=MagicMock(), + context=MagicMock(), + ) + self.assertIsNone(session.active_page) + + def test_set_active_by_index(self): + page1 = _fake_page(url='https://a.com') + page2 = _fake_page(url='https://b.com') + target1 = BrowserTarget(target_id='t1', page=page1, url='https://a.com') + target2 = BrowserTarget(target_id='t2', page=page2, url='https://b.com') + session = BrowserSession( + session_id='s', + browser=MagicMock(), + context=MagicMock(), + targets={'t1': target1, 't2': target2}, + active_target_id='t1', + ) + result = session.set_active_by_index(1) + self.assertTrue(result) + self.assertEqual(session.active_target_id, 't2') + + def test_set_active_by_index_out_of_range(self): + session, _ = self._make_session() + result = session.set_active_by_index(99) + self.assertFalse(result) + + def test_set_active_by_url(self): + page1 = _fake_page(url='https://a.com/page') + target1 = BrowserTarget(target_id='t1', page=page1, url='https://a.com/page') + session = BrowserSession( + session_id='s', + browser=MagicMock(), + context=MagicMock(), + targets={'t1': target1}, + active_target_id='t1', + ) + result = session.set_active_by_url('https://a.com') + self.assertTrue(result) + + def test_bind_and_resolve_refs(self): + session, _ = self._make_session() + elements = [ + {'tag': 'a', 'type': '', 'role': '', 'text': 'Home', + 'href': '/home', 'placeholder': '', 'name': '', 'value': '', + 
'options': []}, + {'tag': 'button', 'type': '', 'role': '', 'text': 'Submit', + 'href': '', 'placeholder': '', 'name': '', 'value': '', + 'options': []}, + ] + session.bind_refs(elements) + self.assertEqual(len(session.refs), 2) + self.assertEqual(session.resolve_ref('r1'), elements[0]) + self.assertEqual(session.resolve_ref('r2'), elements[1]) + self.assertIsNone(session.resolve_ref('r99')) + + def test_bind_refs_clears_previous(self): + session, _ = self._make_session() + session.bind_refs([{'tag': 'a', 'type': '', 'role': '', 'text': 'Old', + 'href': '', 'placeholder': '', 'name': '', + 'value': '', 'options': []}]) + self.assertIn('r1', session.refs) + session.bind_refs([]) + self.assertEqual(session.refs, {}) + + +# --------------------------------------------------------------------------- +# BrowserSnapshot action tests (Playwright mocked) +# --------------------------------------------------------------------------- + +class TestBrowserSnapshot(unittest.TestCase): + """Test BrowserSnapshot using a fully mocked BrowserSessionManager.""" + + def _make_snapshot_action(self): + """Return a BrowserSnapshot with a mocked session manager.""" + snap = BrowserSnapshot() + mock_manager = MagicMock() + snap._session_manager = mock_manager + return snap, mock_manager + + def _make_mock_session(self, page=None): + """Create a mock BrowserSession with a working active_page.""" + if page is None: + page = _fake_page() + session = MagicMock(spec=BrowserSession) + session.active_page = page + session.active_target_id = 't1' + session.targets = {'t1': MagicMock()} + session.refs = {} + session.artifact_dir = '/tmp/artifacts' + return session + + def test_run_returns_snapshot(self): + snap, mock_manager = self._make_snapshot_action() + page = _fake_page( + url='https://example.com', + title='Ex', + body_text='Hello', + elements=[], + ) + session = self._make_mock_session(page) + mock_manager.get_or_create_session.return_value = session + + result = snap.run(session_id='s1') + 
# result is a dict when successful + self.assertIsInstance(result, dict) + self.assertIn('snapshot', result) + self.assertIn('https://example.com', result['url']) + self.assertIn('stats', result) + + def test_run_binds_refs(self): + snap, mock_manager = self._make_snapshot_action() + elements = [ + {'tag': 'button', 'type': '', 'role': '', 'text': 'Go', + 'href': '', 'placeholder': '', 'name': '', 'value': '', + 'options': []}, + ] + page = _fake_page(elements=elements) + session = self._make_mock_session(page) + mock_manager.get_or_create_session.return_value = session + + snap.run(session_id='s1') + session.bind_refs.assert_called_once() + + def test_run_target_index(self): + snap, mock_manager = self._make_snapshot_action() + page = _fake_page() + session = self._make_mock_session(page) + session.set_active_by_index = MagicMock(return_value=True) + mock_manager.get_or_create_session.return_value = session + + snap.run(session_id='s1', target='0') + session.set_active_by_index.assert_called_once_with(0) + + def test_run_target_url(self): + snap, mock_manager = self._make_snapshot_action() + page = _fake_page() + session = self._make_mock_session(page) + session.set_active_by_index = MagicMock(side_effect=ValueError) + session.set_active_by_url = MagicMock(return_value=True) + mock_manager.get_or_create_session.return_value = session + + snap.run(session_id='s1', target='https://example.com') + session.set_active_by_url.assert_called_once_with('https://example.com') + + def test_run_no_active_page(self): + snap, mock_manager = self._make_snapshot_action() + session = MagicMock(spec=BrowserSession) + session.active_page = None + mock_manager.get_or_create_session.return_value = session + + result = snap.run(session_id='s1') + # Should return an ActionReturn with API_ERROR + self.assertIsInstance(result, ActionReturn) + self.assertEqual(result.state, ActionStatusCode.API_ERROR) + + def test_run_session_creation_error(self): + snap, mock_manager = 
self._make_snapshot_action() + mock_manager.get_or_create_session.side_effect = RuntimeError('boom') + + result = snap.run(session_id='s1') + self.assertIsInstance(result, ActionReturn) + self.assertEqual(result.state, ActionStatusCode.API_ERROR) + self.assertIn('boom', result.errmsg) + + def test_run_include_image(self): + snap, mock_manager = self._make_snapshot_action() + page = _fake_page() + page.screenshot = MagicMock() + session = self._make_mock_session(page) + session.artifact_dir = '/tmp/test_artifacts' + mock_manager.get_or_create_session.return_value = session + + with patch('os.makedirs'), patch('os.path.abspath', return_value='/abs/path.png'): + result = snap.run(session_id='s1', include_image=True) + + if isinstance(result, dict): + # screenshot_path may be empty if abspath mock didn't fully wire through + self.assertIn('screenshot_path', result) + + def test_run_per_call_overrides(self): + snap, mock_manager = self._make_snapshot_action() + page = _fake_page(body_text='x' * 5000) + session = self._make_mock_session(page) + mock_manager.get_or_create_session.return_value = session + + result = snap.run( + session_id='s1', + max_total_chars=200, + max_text_chars=50, + max_refs=5, + ) + if isinstance(result, dict): + snapshot = result['snapshot'] + # With a 200-char cap the snapshot should be truncated + self.assertLessEqual(len(snapshot), 200 + len('\n[... 
truncated ...]')) + + def test_playwright_unavailable(self): + """BrowserSnapshot.run should return API_ERROR when playwright missing.""" + snap = BrowserSnapshot() + snap._session_manager = None + + # Patch the module-level flag + import lagent.actions.browser_snapshot as bsmod + original = bsmod.PLAYWRIGHT_AVAILABLE + bsmod.PLAYWRIGHT_AVAILABLE = False + try: + result = snap.run(session_id='s1') + self.assertIsInstance(result, ActionReturn) + self.assertEqual(result.state, ActionStatusCode.API_ERROR) + self.assertIn('playwright', result.errmsg.lower()) + finally: + bsmod.PLAYWRIGHT_AVAILABLE = original + + def test_description_is_dict(self): + snap = BrowserSnapshot() + desc = snap.description + self.assertIsInstance(desc, dict) + self.assertIn('name', desc) + self.assertEqual(desc['name'], 'BrowserSnapshot') + + def test_action_call_interface(self): + """BrowserSnapshot should work through the standard __call__ interface.""" + snap = BrowserSnapshot() + mock_manager = MagicMock() + snap._session_manager = mock_manager + + page = _fake_page() + session = self._make_mock_session(page) + mock_manager.get_or_create_session.return_value = session + + import json + action_return = snap(json.dumps({'session_id': 's1'})) + self.assertIsNotNone(action_return) + self.assertEqual(action_return.state, ActionStatusCode.SUCCESS) + + +# --------------------------------------------------------------------------- +# BrowserSessionManager tests (no real browser; manager calls are mocked) +# --------------------------------------------------------------------------- + +class TestBrowserSessionManager(unittest.TestCase): + """Light structural tests that do not launch a real browser.""" + + def test_singleton(self): + m1 = BrowserSessionManager() + m2 = BrowserSessionManager() + self.assertIs(m1, m2) + + def test_list_sessions_empty_initially(self): + manager = BrowserSessionManager() + # Only check that list_sessions returns a list (may have sessions from + # other tests since the 
manager is a singleton). + self.assertIsInstance(manager.list_sessions(), list) + + +if __name__ == '__main__': + unittest.main()