Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/crawlee/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json
import logging
import re
from collections.abc import Callable, Coroutine, Sequence
from collections.abc import Awaitable, Coroutine, Sequence

from typing_extensions import NotRequired, Required, Self, Unpack

Expand All @@ -39,6 +39,9 @@

HttpPayload = bytes

DeferredCleanupCallback = Callable[[], 'Awaitable[Any]']
"""An async callback to be called after request processing completes (including error handlers)."""

RequestTransformAction = Literal['skip', 'unchanged']

EnqueueStrategy = Literal['all', 'same-domain', 'same-hostname', 'same-origin']
Expand Down Expand Up @@ -661,6 +664,9 @@ class BasicCrawlingContext:
log: logging.Logger
"""Logger instance."""

register_deferred_cleanup: Callable[[DeferredCleanupCallback], None]
"""Register an async callback to be called after request processing completes (including error handlers)."""

async def get_snapshot(self) -> PageSnapshot:
    """Capture and return a snapshot of the crawled page."""
    snapshot = PageSnapshot()
    return snapshot
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ async def get_input_state(
get_key_value_store=result.get_key_value_store,
use_state=use_state_function,
log=context.log,
register_deferred_cleanup=context.register_deferred_cleanup,
)

try:
Expand Down
10 changes: 10 additions & 0 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,8 @@ async def __run_task_function(self) -> None:
proxy_info = await self._get_proxy_info(request, session)
result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store, request=request)

deferred_cleanup: list[Callable[[], Awaitable[None]]] = []

context = BasicCrawlingContext(
request=result.request,
session=session,
Expand All @@ -1423,6 +1425,7 @@ async def __run_task_function(self) -> None:
get_key_value_store=result.get_key_value_store,
use_state=self._use_state,
log=self._logger,
register_deferred_cleanup=deferred_cleanup.append,
)
self._context_result_map[context] = result

Expand Down Expand Up @@ -1509,6 +1512,13 @@ async def __run_task_function(self) -> None:
)
raise

finally:
for cleanup in deferred_cleanup:
try:
await cleanup()
except Exception: # noqa: PERF203
self._logger.exception('Error in deferred cleanup')

async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
context.request.state = RequestState.BEFORE_NAV
await self._context_pipeline(
Expand Down
120 changes: 62 additions & 58 deletions src/crawlee/crawlers/_playwright/_playwright_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ async def _open_page(
proxy_info=context.proxy_info,
get_key_value_store=context.get_key_value_store,
log=context.log,
register_deferred_cleanup=context.register_deferred_cleanup,
page=crawlee_page.page,
block_requests=partial(block_requests, page=crawlee_page.page),
goto_options=GotoOptions(**self._goto_options),
Expand Down Expand Up @@ -296,63 +297,69 @@ async def _navigate(
The enhanced crawling context with the Playwright-specific features (page, response, enqueue_links,
infinite_scroll and block_requests).
"""
async with context.page:
if context.session:
session_cookies = context.session.cookies.get_cookies_as_playwright_format()
await self._update_cookies(context.page, session_cookies)

if context.request.headers:
await context.page.set_extra_http_headers(context.request.headers.model_dump())
# Navigate to the URL and get response.
if context.request.method != 'GET':
# Call the notification only once
warnings.warn(
'Using other request methods than GET or adding payloads has a high impact on performance'
' in recent versions of Playwright. Use only when necessary.',
category=UserWarning,
stacklevel=2,
)
# Enter the page context manager, but defer its cleanup (page.close()) so the page stays open
# during error handler execution.
await context.page.__aenter__()

route_handler = self._prepare_request_interceptor(
method=context.request.method,
headers=context.request.headers,
payload=context.request.payload,
)
context.register_deferred_cleanup(lambda: context.page.__aexit__(None, None, None))

# Set route_handler only for current request
await context.page.route(context.request.url, route_handler)
if context.session:
session_cookies = context.session.cookies.get_cookies_as_playwright_format()
await self._update_cookies(context.page, session_cookies)

if context.request.headers:
await context.page.set_extra_http_headers(context.request.headers.model_dump())
# Navigate to the URL and get response.
if context.request.method != 'GET':
# Call the notification only once
warnings.warn(
'Using other request methods than GET or adding payloads has a high impact on performance'
' in recent versions of Playwright. Use only when necessary.',
category=UserWarning,
stacklevel=2,
)

try:
async with self._shared_navigation_timeouts[id(context)] as remaining_timeout:
response = await context.page.goto(
context.request.url, timeout=remaining_timeout.total_seconds() * 1000, **context.goto_options
)
context.request.state = RequestState.AFTER_NAV
except playwright.async_api.TimeoutError as exc:
raise asyncio.TimeoutError from exc

if response is None:
raise SessionError(f'Failed to load the URL: {context.request.url}')

# Set the loaded URL to the actual URL after redirection.
context.request.loaded_url = context.page.url

yield PlaywrightPostNavCrawlingContext(
request=context.request,
session=context.session,
add_requests=context.add_requests,
send_request=context.send_request,
push_data=context.push_data,
use_state=context.use_state,
proxy_info=context.proxy_info,
get_key_value_store=context.get_key_value_store,
log=context.log,
page=context.page,
block_requests=context.block_requests,
goto_options=context.goto_options,
response=response,
route_handler = self._prepare_request_interceptor(
method=context.request.method,
headers=context.request.headers,
payload=context.request.payload,
)

# Set route_handler only for current request
await context.page.route(context.request.url, route_handler)

try:
async with self._shared_navigation_timeouts[id(context)] as remaining_timeout:
response = await context.page.goto(
context.request.url, timeout=remaining_timeout.total_seconds() * 1000, **context.goto_options
)
context.request.state = RequestState.AFTER_NAV
except playwright.async_api.TimeoutError as exc:
raise asyncio.TimeoutError from exc

if response is None:
raise SessionError(f'Failed to load the URL: {context.request.url}')

# Set the loaded URL to the actual URL after redirection.
context.request.loaded_url = context.page.url

yield PlaywrightPostNavCrawlingContext(
request=context.request,
session=context.session,
add_requests=context.add_requests,
send_request=context.send_request,
push_data=context.push_data,
use_state=context.use_state,
proxy_info=context.proxy_info,
get_key_value_store=context.get_key_value_store,
log=context.log,
register_deferred_cleanup=context.register_deferred_cleanup,
page=context.page,
block_requests=context.block_requests,
goto_options=context.goto_options,
response=response,
)

def _create_extract_links_function(self, context: PlaywrightPreNavCrawlingContext) -> ExtractLinksFunction:
"""Create a callback function for extracting links from context.

Expand Down Expand Up @@ -495,10 +502,10 @@ async def _execute_post_navigation_hooks(

async def _create_crawling_context(
self, context: PlaywrightPostNavCrawlingContext
) -> AsyncGenerator[PlaywrightCrawlingContext, Exception | None]:
) -> AsyncGenerator[PlaywrightCrawlingContext, None]:
extract_links = self._create_extract_links_function(context)

error = yield PlaywrightCrawlingContext(
yield PlaywrightCrawlingContext(
request=context.request,
session=context.session,
add_requests=context.add_requests,
Expand All @@ -508,6 +515,7 @@ async def _create_crawling_context(
proxy_info=context.proxy_info,
get_key_value_store=context.get_key_value_store,
log=context.log,
register_deferred_cleanup=context.register_deferred_cleanup,
page=context.page,
goto_options=context.goto_options,
response=context.response,
Expand All @@ -521,10 +529,6 @@ async def _create_crawling_context(
pw_cookies = await self._get_cookies(context.page)
context.session.cookies.set_cookies_from_playwright_format(pw_cookies)

# Collect data in case of errors, before the page object is closed.
if error:
await self.statistics.error_tracker.add(error=error, context=context, early=True)

def pre_navigation_hook(self, hook: Callable[[PlaywrightPreNavCrawlingContext], Awaitable[None]]) -> None:
"""Register a hook to be called before each navigation.

Expand Down
12 changes: 0 additions & 12 deletions src/crawlee/statistics/_error_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,31 +41,19 @@ def __init__(
raise ValueError('`show_error_message` must be `True` if `show_full_message` is set to `True`')
self.show_full_message = show_full_message
self._errors: ErrorFilenameGroups = defaultdict(lambda: defaultdict(Counter))
self._early_reported_errors = set[int]()

async def add(
self,
error: Exception,
*,
context: BasicCrawlingContext | None = None,
early: bool = False,
) -> None:
"""Add an error in the statistics.

Args:
error: Error to be added to statistics.
context: Context used to collect error snapshot.
early: Flag indicating that the error is added earlier than usual to have access to resources that will be
closed before normal error collection. This prevents double reporting during normal error collection.
"""
if id(error) in self._early_reported_errors:
# Error had to be collected earlier before relevant resources are closed.
self._early_reported_errors.remove(id(error))
return

if early:
self._early_reported_errors.add(id(error))

error_group_name = error.__class__.__name__ if self.show_error_name else None
error_group_message = self._get_error_message(error)
new_error_group_message = '' # In case of wildcard similarity match
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/crawlers/_basic/test_context_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ async def test_calls_consumer_without_middleware() -> None:
use_state=AsyncMock(),
get_key_value_store=AsyncMock(),
log=logging.getLogger(),
register_deferred_cleanup=lambda _: None,
)

await pipeline(context, consumer)
Expand Down Expand Up @@ -68,6 +69,7 @@ async def middleware_a(context: BasicCrawlingContext) -> AsyncGenerator[Enhanced
use_state=AsyncMock(),
get_key_value_store=AsyncMock(),
log=logging.getLogger(),
register_deferred_cleanup=context.register_deferred_cleanup,
)
events.append('middleware_a_out')

Expand All @@ -85,6 +87,7 @@ async def middleware_b(context: EnhancedCrawlingContext) -> AsyncGenerator[MoreE
use_state=AsyncMock(),
get_key_value_store=AsyncMock(),
log=logging.getLogger(),
register_deferred_cleanup=context.register_deferred_cleanup,
)
events.append('middleware_b_out')

Expand All @@ -100,6 +103,7 @@ async def middleware_b(context: EnhancedCrawlingContext) -> AsyncGenerator[MoreE
use_state=AsyncMock(),
get_key_value_store=AsyncMock(),
log=logging.getLogger(),
register_deferred_cleanup=lambda _: None,
)
await pipeline(context, consumer)

Expand All @@ -126,6 +130,7 @@ async def test_wraps_consumer_errors() -> None:
use_state=AsyncMock(),
get_key_value_store=AsyncMock(),
log=logging.getLogger(),
register_deferred_cleanup=lambda _: None,
)

with pytest.raises(RequestHandlerError):
Expand Down Expand Up @@ -155,6 +160,7 @@ async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingC
use_state=AsyncMock(),
get_key_value_store=AsyncMock(),
log=logging.getLogger(),
register_deferred_cleanup=lambda _: None,
)

with pytest.raises(ContextPipelineInitializationError):
Expand Down Expand Up @@ -187,6 +193,7 @@ async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingC
use_state=AsyncMock(),
get_key_value_store=AsyncMock(),
log=logging.getLogger(),
register_deferred_cleanup=lambda _: None,
)

with pytest.raises(ContextPipelineFinalizationError):
Expand Down
36 changes: 34 additions & 2 deletions tests/unit/crawlers/_playwright/test_playwright_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
service_locator,
)
from crawlee.configuration import Configuration
from crawlee.crawlers import PlaywrightCrawler
from crawlee.crawlers import (
PlaywrightCrawler,
PlaywrightCrawlingContext,
)
from crawlee.fingerprint_suite import (
DefaultFingerprintGenerator,
FingerprintGenerator,
Expand Down Expand Up @@ -49,7 +52,6 @@
from crawlee.browsers._types import BrowserType
from crawlee.crawlers import (
BasicCrawlingContext,
PlaywrightCrawlingContext,
PlaywrightPostNavCrawlingContext,
PlaywrightPreNavCrawlingContext,
)
Expand Down Expand Up @@ -1203,3 +1205,33 @@ async def post_nav_hook_2(_context: PlaywrightPostNavCrawlingContext) -> None:
'post-navigation-hook 2',
'final handler',
]


async def test_error_handler_can_access_page(server_url: URL) -> None:
    """Test that the error handler can access the Page object via PlaywrightCrawlingContext."""
    crawler = PlaywrightCrawler(max_request_retries=2)

    # The default handler always crashes, so both the error handler (on retries)
    # and the failed-request handler (after the final retry) get invoked.
    handler_mock = mock.AsyncMock(side_effect=RuntimeError('Intentional crash'))
    crawler.router.default_handler(handler_mock)

    error_handler_calls: list[str | None] = []

    @crawler.error_handler
    async def error_handler(context: BasicCrawlingContext | PlaywrightCrawlingContext, _error: Exception) -> None:
        # Record the page content if the Playwright page is still available, None otherwise.
        if isinstance(context, PlaywrightCrawlingContext):
            error_handler_calls.append(await context.page.content())
        else:
            error_handler_calls.append(None)

    failed_handler_calls: list[str | None] = []

    @crawler.failed_request_handler
    async def failed_handler(context: BasicCrawlingContext | PlaywrightCrawlingContext, _error: Exception) -> None:
        if isinstance(context, PlaywrightCrawlingContext):
            failed_handler_calls.append(await context.page.content())
        else:
            failed_handler_calls.append(None)

    await crawler.run([str(server_url / 'hello-world')])

    # Two retries -> two error-handler invocations, then one failed-request invocation,
    # each with the page still open and its content readable.
    expected_content = HELLO_WORLD.decode()
    assert error_handler_calls == [expected_content, expected_content]
    assert failed_handler_calls == [expected_content]
1 change: 1 addition & 0 deletions tests/unit/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self, *, label: str | None) -> None:
use_state=AsyncMock(),
get_key_value_store=AsyncMock(),
log=logging.getLogger(),
register_deferred_cleanup=lambda _: None,
)


Expand Down
Loading