From 206aa14aa9e78ae5b993ae8d682c5bb37499fd06 Mon Sep 17 00:00:00 2001 From: CGW406 <13565294+cgw406@user.noreply.gitee.com> Date: Wed, 22 Apr 2026 16:21:41 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat(=E9=87=8D=E8=AF=95=E6=9C=BA=E5=88=B6):?= =?UTF-8?q?=20=E6=B7=BB=E5=8A=A0=E8=87=AA=E5=8A=A8=E9=87=8D=E8=AF=95?= =?UTF-8?q?=E8=A3=85=E9=A5=B0=E5=99=A8=E5=B9=B6=E6=9B=BF=E6=8D=A2=E5=8E=9F?= =?UTF-8?q?=E6=9C=89=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加自定义的auto_retry装饰器实现网络请求自动重试功能 替换xhs和douyin平台中使用的tenacity重试机制 新增ENABLE_AUTO_RETRY配置开关控制重试功能 --- config/base_config.py | 4 + media_platform/douyin/client.py | 3 + media_platform/xhs/client.py | 6 +- media_platform/xhs/core.py | 5 +- tools/retry_decorator.py | 140 ++++++++++++++++++++++++++++++++ 5 files changed, 152 insertions(+), 6 deletions(-) create mode 100644 tools/retry_decorator.py diff --git a/config/base_config.py b/config/base_config.py index 961d2169a..0274b34a5 100644 --- a/config/base_config.py +++ b/config/base_config.py @@ -117,6 +117,10 @@ # 爬取间隔时间 CRAWLER_MAX_SLEEP_SEC = 2 +# 自动重试机制开关(仅针对小红书和抖音平台) +# 开启后,网络相关异常(超时、连接失败等)会自动重试3次,采用指数退避策略(1s、2s、4s) +ENABLE_AUTO_RETRY = True + from .bilibili_config import * from .xhs_config import * from .dy_config import * diff --git a/media_platform/douyin/client.py b/media_platform/douyin/client.py index b080f836b..84cac7443 100644 --- a/media_platform/douyin/client.py +++ b/media_platform/douyin/client.py @@ -29,6 +29,7 @@ from base.base_crawler import AbstractApiClient from proxy.proxy_mixin import ProxyRefreshMixin from tools import utils +from tools.retry_decorator import auto_retry from var import request_keyword_var if TYPE_CHECKING: @@ -112,6 +113,7 @@ async def __process_req_params( a_bogus = await get_a_bogus(uri, query_string, post_data, headers["User-Agent"], self.playwright_page) params["a_bogus"] = a_bogus + @auto_retry(max_retries=3, base_delay=1.0) async def request(self, method, url, **kwargs): # 每次请求前检测代理是否过期 await self._refresh_proxy_if_expired() @@ -332,6 +334,7 @@ async def get_all_user_aweme_posts(self, sec_user_id: str, callback: Optional[Ca result.extend(aweme_list) return result + @auto_retry(max_retries=3, base_delay=1.0) async def get_aweme_media(self, url: str) -> Union[bytes, None]: async with httpx.AsyncClient(proxy=self.proxy) as client: try: diff --git a/media_platform/xhs/client.py b/media_platform/xhs/client.py index f1df0de95..00d025d22 100644 --- a/media_platform/xhs/client.py +++ b/media_platform/xhs/client.py @@ -24,12 +24,12 @@ import httpx from playwright.async_api import BrowserContext, Page -from tenacity import retry, stop_after_attempt, wait_fixed import config from base.base_crawler import AbstractApiClient from proxy.proxy_mixin import ProxyRefreshMixin from tools import utils +from tools.retry_decorator import auto_retry if TYPE_CHECKING: from proxy.proxy_ip_pool import ProxyIpPool @@ -109,7 +109,7 @@ async def _pre_headers(self, url: str, params: Optional[Dict] = None, payload: O self.headers.update(headers) return self.headers - @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) + @auto_retry(max_retries=3, base_delay=1.0) async def request(self, method, url, **kwargs) -> Union[str, Any]: """ Wrapper for httpx common request method, processes request response @@ -613,7 +613,7 @@ async def get_note_short_url(self, note_id: str) -> Dict: data = {"original_url": f"{self._domain}/discovery/item/{note_id}"} return await self.post(uri, data=data, return_response=True) - @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) + @auto_retry(max_retries=3, base_delay=1.0) async def get_note_by_id_from_html( self, note_id: str, diff --git a/media_platform/xhs/core.py b/media_platform/xhs/core.py index 704746846..807fbafba 100644 --- a/media_platform/xhs/core.py +++ b/media_platform/xhs/core.py @@ -30,8 +30,6 @@ Playwright, async_playwright, ) -from tenacity import RetryError - import config from base.base_crawler import AbstractCrawler from model.m_xiaohongshu import NoteUrlInfo, CreatorUrlInfo @@ -39,6 +37,7 @@ from store import xhs as xhs_store from tools import utils from tools.cdp_browser import CDPBrowserManager +from tools.retry_decorator import RetryExhaustedError from var import crawler_type_var, source_keyword_var from .client import XiaoHongShuClient @@ -291,7 +290,7 @@ async def get_note_detail_async_task( try: try: note_detail = await self.xhs_client.get_note_by_id(note_id, xsec_source, xsec_token) - except RetryError: + except RetryExhaustedError: pass if not note_detail: diff --git a/tools/retry_decorator.py b/tools/retry_decorator.py new file mode 100644 index 000000000..321688446 --- /dev/null +++ b/tools/retry_decorator.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2025 relakkes@gmail.com +# +# This file is part of MediaCrawler project. +# Repository: https://github.com/NanmiCoder/MediaCrawler/blob/main/tools/retry_decorator.py +# GitHub: https://github.com/NanmiCoder +# Licensed under NON-COMMERCIAL LEARNING LICENSE 1.1 +# +# 声明:本代码仅供学习和研究目的使用。使用者应遵守以下原则: +# 1. 不得用于任何商业用途。 +# 2. 使用时应遵守目标平台的使用条款和robots.txt规则。 +# 3. 不得进行大规模爬取或对平台造成运营干扰。 +# 4. 应合理控制请求频率,避免给目标平台带来不必要的负担。 +# 5. 不得用于任何非法或不当的用途。 +# +# 详细许可条款请参阅项目根目录下的LICENSE文件。 +# 使用本代码即表示您同意遵守上述原则和LICENSE中的所有条款。 + +import asyncio +from functools import wraps +from typing import Callable, Type, Tuple + +import httpx + +import config +from tools import utils + + +class RetryExhaustedError(Exception): + pass + + +NETWORK_RELATED_EXCEPTIONS: Tuple[Type[Exception], ...] = ( + httpx.ConnectError, + httpx.ReadTimeout, + httpx.ConnectTimeout, + httpx.WriteTimeout, + httpx.PoolTimeout, + httpx.TimeoutException, + httpx.NetworkError, + httpx.TransportError, + httpx.ProtocolError, + ConnectionError, + TimeoutError, +) + + +def is_network_exception(exception: Exception) -> bool: + return isinstance(exception, NETWORK_RELATED_EXCEPTIONS) + + +def auto_retry( + max_retries: int = 3, + base_delay: float = 1.0, + enable_config: bool = True, +): + def decorator(func: Callable): + @wraps(func) + async def async_wrapper(*args, **kwargs): + if enable_config and not config.ENABLE_AUTO_RETRY: + return await func(*args, **kwargs) + + last_exception = None + for attempt in range(max_retries + 1): + try: + return await func(*args, **kwargs) + except Exception as e: + last_exception = e + if not is_network_exception(e): + utils.logger.warning( + f"[auto_retry] Non-network exception occurred, not retrying: " + f"{e.__class__.__name__}: {str(e)}" + ) + raise + + if attempt >= max_retries: + utils.logger.error( + f"[auto_retry] All {max_retries} retry attempts failed. " + f"Last exception: {e.__class__.__name__}: {str(e)}" + ) + raise RetryExhaustedError( + f"All {max_retries} retry attempts failed. " + f"Last exception: {e.__class__.__name__}: {str(e)}" + ) from e + + delay = base_delay * (2 ** attempt) + utils.logger.warning( + f"[auto_retry] Network exception occurred: {e.__class__.__name__}: {str(e)}. " + f"Retry {attempt + 1}/{max_retries} in {delay}s..." + ) + await asyncio.sleep(delay) + + if last_exception: + raise last_exception + + @wraps(func) + def sync_wrapper(*args, **kwargs): + if enable_config and not config.ENABLE_AUTO_RETRY: + return func(*args, **kwargs) + + last_exception = None + import time + for attempt in range(max_retries + 1): + try: + return func(*args, **kwargs) + except Exception as e: + last_exception = e + if not is_network_exception(e): + utils.logger.warning( + f"[auto_retry] Non-network exception occurred, not retrying: " + f"{e.__class__.__name__}: {str(e)}" + ) + raise + + if attempt >= max_retries: + utils.logger.error( + f"[auto_retry] All {max_retries} retry attempts failed. " + f"Last exception: {e.__class__.__name__}: {str(e)}" + ) + raise RetryExhaustedError( + f"All {max_retries} retry attempts failed. " + f"Last exception: {e.__class__.__name__}: {str(e)}" + ) from e + + delay = base_delay * (2 ** attempt) + utils.logger.warning( + f"[auto_retry] Network exception occurred: {e.__class__.__name__}: {str(e)}. " + f"Retry {attempt + 1}/{max_retries} in {delay}s..." + ) + time.sleep(delay) + + if last_exception: + raise last_exception + + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + return decorator From 0775e012b7cc656fd2e78a3fe297d5f14be38d44 Mon Sep 17 00:00:00 2001 From: CGW406 <13565294+cgw406@user.noreply.gitee.com> Date: Wed, 22 Apr 2026 17:58:51 +0800 Subject: [PATCH 2/2] =?UTF-8?q?test(retry):=20=E6=B7=BB=E5=8A=A0=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E9=87=8D=E8=AF=95=E8=A3=85=E9=A5=B0=E5=99=A8=E7=9A=84?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加验证自动重试装饰器功能的测试脚本和单元测试,包括: 1. 网络异常触发重试 2. 指数退避时间验证 3. 业务异常不触发重试 4. 配置开关控制重试功能 5. 同步/异步函数支持 --- tests/test_retry_decorator.py | 349 ++++++++++++++++++++++++++++++++++ verify_retry.py | 267 ++++++++++++++++++++++++++ 2 files changed, 616 insertions(+) create mode 100644 tests/test_retry_decorator.py create mode 100644 verify_retry.py diff --git a/tests/test_retry_decorator.py b/tests/test_retry_decorator.py new file mode 100644 index 000000000..9529fa7cb --- /dev/null +++ b/tests/test_retry_decorator.py @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2025 relakkes@gmail.com +# +# This file is part of MediaCrawler project. +# Repository: https://github.com/NanmiCoder/MediaCrawler/blob/main/tests/test_retry_decorator.py +# GitHub: https://github.com/NanmiCoder +# Licensed under NON-COMMERCIAL LEARNING LICENSE 1.1 +# +# 声明:本代码仅供学习和研究目的使用。使用者应遵守以下原则: +# 1. 不得用于任何商业用途。 +# 2. 使用时应遵守目标平台的使用条款和robots.txt规则。 +# 3. 不得进行大规模爬取或对平台造成运营干扰。 +# 4. 应合理控制请求频率,避免给目标平台带来不必要的负担。 +# 5. 不得用于任何非法或不当的用途。 +# +# 详细许可条款请参阅项目根目录下的LICENSE文件。 +# 使用本代码即表示您同意遵守上述原则和LICENSE中的所有条款。 + +""" +Tests for auto_retry decorator +""" + +import asyncio +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +import config +from tools.retry_decorator import ( + RetryExhaustedError, + auto_retry, + is_network_exception, + NETWORK_RELATED_EXCEPTIONS, +) +from media_platform.xhs.exception import DataFetchError + + +class TestNetworkExceptionDetection: + """Tests for network exception detection""" + + @pytest.mark.parametrize( + "exception_class,expected", + [ + (httpx.ConnectError, True), + (httpx.ReadTimeout, True), + (httpx.ConnectTimeout, True), + (httpx.WriteTimeout, True), + (httpx.PoolTimeout, True), + (httpx.TimeoutException, True), + (httpx.NetworkError, True), + (httpx.TransportError, True), + (httpx.ProtocolError, True), + (ConnectionError, True), + (TimeoutError, True), + ], + ) + def test_network_exceptions_are_detected(self, exception_class, expected): + exc = exception_class("test error") + assert is_network_exception(exc) == expected + + def test_non_network_exception_not_detected(self): + exc = DataFetchError("business logic error") + assert not is_network_exception(exc) + + def test_general_exception_not_detected(self): + exc = ValueError("some error") + assert not is_network_exception(exc) + + +class TestRetryDecoratorAsync: + """Tests for auto_retry decorator on async functions""" + + @pytest.mark.asyncio + async def test_success_on_first_attempt(self): + call_count = 0 + + @auto_retry(max_retries=3, base_delay=1.0, enable_config=False) + async def test_func(): + nonlocal call_count + call_count += 1 + return "success" + + result = await test_func() + assert result == "success" + assert call_count == 1 + + @pytest.mark.asyncio + async def test_retry_on_network_exception(self): + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=False) + async def test_func(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise httpx.ReadTimeout("timeout") + return "success" + + result = await test_func() + assert result == "success" + assert call_count == 3 + + @pytest.mark.asyncio + async def test_retry_exhausted_raises_retry_exhausted_error(self): + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=False) + async def test_func(): + nonlocal call_count + call_count += 1 + raise httpx.ReadTimeout("timeout") + + with pytest.raises(RetryExhaustedError) as exc_info: + await test_func() + + assert call_count == 4 + assert "ReadTimeout" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_no_retry_on_business_exception(self): + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=False) + async def test_func(): + nonlocal call_count + call_count += 1 + raise DataFetchError("business error") + + with pytest.raises(DataFetchError): + await test_func() + + assert call_count == 1 + + @pytest.mark.asyncio + async def test_exponential_backoff_timing(self): + sleep_times = [] + + @auto_retry(max_retries=3, base_delay=1.0, enable_config=False) + async def test_func(): + raise httpx.ReadTimeout("timeout") + + original_sleep = asyncio.sleep + + async def mock_sleep(delay): + sleep_times.append(delay) + + with patch.object(asyncio, 'sleep', side_effect=mock_sleep): + with pytest.raises(RetryExhaustedError): + await test_func() + + assert sleep_times == [1.0, 2.0, 4.0] + + @pytest.mark.asyncio + async def test_config_switch_disabled_no_retry(self): + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=True) + async def test_func(): + nonlocal call_count + call_count += 1 + raise httpx.ReadTimeout("timeout") + + with patch.object(config, 'ENABLE_AUTO_RETRY', False): + with pytest.raises(httpx.ReadTimeout): + await test_func() + + assert call_count == 1 + + @pytest.mark.asyncio + async def test_config_switch_enabled_retry(self): + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=True) + async def test_func(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise httpx.ReadTimeout("timeout") + return "success" + + with patch.object(config, 'ENABLE_AUTO_RETRY', True): + result = await test_func() + + assert result == "success" + assert call_count == 3 + + @pytest.mark.asyncio + async def test_preserves_function_metadata(self): + @auto_retry(max_retries=3, base_delay=0.1, enable_config=False) + async def my_test_function(arg1, arg2="default"): + """This is a test docstring""" + return f"{arg1}-{arg2}" + + assert my_test_function.__name__ == "my_test_function" + assert my_test_function.__doc__ == "This is a test docstring" + + result = await my_test_function("hello", arg2="world") + assert result == "hello-world" + + +class TestRetryDecoratorSync: + """Tests for auto_retry decorator on sync functions""" + + def test_success_on_first_attempt(self): + call_count = 0 + + @auto_retry(max_retries=3, base_delay=1.0, enable_config=False) + def test_func(): + nonlocal call_count + call_count += 1 + return "success" + + result = test_func() + assert result == "success" + assert call_count == 1 + + def test_retry_on_network_exception(self): + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=False) + def test_func(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise httpx.ReadTimeout("timeout") + return "success" + + result = test_func() + assert result == "success" + assert call_count == 3 + + def test_no_retry_on_business_exception(self): + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=False) + def test_func(): + nonlocal call_count + call_count += 1 + raise DataFetchError("business error") + + with pytest.raises(DataFetchError): + test_func() + + assert call_count == 1 + + def test_exponential_backoff_timing(self): + sleep_times = [] + import time + + @auto_retry(max_retries=3, base_delay=1.0, enable_config=False) + def test_func(): + raise httpx.ReadTimeout("timeout") + + original_sleep = time.sleep + + def mock_sleep(delay): + sleep_times.append(delay) + + with patch.object(time, 'sleep', side_effect=mock_sleep): + with pytest.raises(RetryExhaustedError): + test_func() + + assert sleep_times == [1.0, 2.0, 4.0] + + +class TestRetryLogging: + """Tests for retry logging""" + + @pytest.mark.asyncio + async def test_logs_retry_attempts(self): + from tools import utils + + log_messages = [] + + def mock_warning(msg): + log_messages.append(msg) + + original_warning = utils.logger.warning + + @auto_retry(max_retries=2, base_delay=0.1, enable_config=False) + async def test_func(): + raise httpx.ReadTimeout("test timeout error") + + with patch.object(utils.logger, 'warning', side_effect=mock_warning): + with pytest.raises(RetryExhaustedError): + await test_func() + + assert len(log_messages) == 2 + assert "Network exception occurred" in log_messages[0] + assert "ReadTimeout" in log_messages[0] + assert "Retry 1/2" in log_messages[0] + assert "Retry 2/2" in log_messages[1] + + +class TestIntegrationScenarios: + """Integration tests simulating real use cases""" + + @pytest.mark.asyncio + async def test_mixed_exceptions(self): + call_count = 0 + raised_exceptions = [] + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=False) + async def test_func(): + nonlocal call_count + call_count += 1 + + if call_count == 1: + exc = httpx.ConnectError("connection failed") + raised_exceptions.append(exc) + raise exc + elif call_count == 2: + exc = httpx.ReadTimeout("timeout") + raised_exceptions.append(exc) + raise exc + else: + return "success" + + result = await test_func() + assert result == "success" + assert call_count == 3 + assert len(raised_exceptions) == 2 + + @pytest.mark.asyncio + async def test_retry_then_business_error(self): + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=False) + async def test_func(): + nonlocal call_count + call_count += 1 + + if call_count == 1: + raise httpx.ConnectError("connection failed") + elif call_count == 2: + raise DataFetchError("invalid data") + else: + return "success" + + with pytest.raises(DataFetchError): + await test_func() + + assert call_count == 2 diff --git a/verify_retry.py b/verify_retry.py new file mode 100644 index 000000000..1718f60b9 --- /dev/null +++ b/verify_retry.py @@ -0,0 +1,267 @@ +# -*- coding: utf-8 -*- +""" +Simple verification script for auto_retry decorator +Run this script to verify: +1. Network exceptions trigger retries +2. Exponential backoff timing (1s, 2s, 4s) +3. Business exceptions do NOT trigger retries +4. Config switch ENABLE_AUTO_RETRY=False disables retries +""" + +import asyncio +import sys +import time +from pathlib import Path + +project_root = Path(__file__).parent +sys.path.insert(0, str(project_root)) + +import httpx +import config + +from tools.retry_decorator import ( + RetryExhaustedError, + auto_retry, + is_network_exception, + NETWORK_RELATED_EXCEPTIONS, +) +from media_platform.xhs.exception import DataFetchError + + +async def test_1_network_exception_retries(): + """Test 1: Network exception triggers retries""" + print("\n" + "=" * 60) + print("Test 1: Network exception triggers retries") + print("=" * 60) + + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=False) + async def test_func(): + nonlocal call_count + call_count += 1 + print(f" [Attempt {call_count}] Called, about to raise ReadTimeout...") + if call_count <= 2: + raise httpx.ReadTimeout("Simulated timeout error") + return "success" + + result = await test_func() + print(f" Result: {result}") + print(f" Total calls: {call_count}") + + assert call_count == 3, f"Expected 3 calls, got {call_count}" + print(" ✓ PASSED: Retry mechanism works for network exceptions") + + +async def test_2_exponential_backoff(): + """Test 2: Exponential backoff timing (1s, 2s, 4s)""" + print("\n" + "=" * 60) + print("Test 2: Exponential backoff timing (1s, 2s, 4s)") + print("=" * 60) + + sleep_times = [] + original_sleep = asyncio.sleep + + async def mock_sleep(delay): + sleep_times.append(delay) + print(f" [Mock sleep] Sleeping for {delay}s") + + asyncio.sleep = mock_sleep + + try: + @auto_retry(max_retries=3, base_delay=1.0, enable_config=False) + async def test_func(): + raise httpx.ConnectError("Simulated connection error") + + try: + await test_func() + except RetryExhaustedError as e: + print(f" Expected RetryExhaustedError: {e}") + + print(f" Sleep intervals recorded: {sleep_times}") + assert sleep_times == [1.0, 2.0, 4.0], f"Expected [1.0, 2.0, 4.0], got {sleep_times}" + print(" ✓ PASSED: Exponential backoff timing is correct (1s, 2s, 4s)") + finally: + asyncio.sleep = original_sleep + + +async def test_3_business_exception_no_retry(): + """Test 3: Business exception does NOT trigger retry""" + print("\n" + "=" * 60) + print("Test 3: Business exception (DataFetchError) does NOT trigger retry") + print("=" * 60) + + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.1, enable_config=False) + async def test_func(): + nonlocal call_count + call_count += 1 + print(f" [Attempt {call_count}] Called, about to raise DataFetchError...") + raise DataFetchError("Business logic error: invalid data") + + try: + await test_func() + except DataFetchError as e: + print(f" Expected DataFetchError: {e}") + + print(f" Total calls: {call_count}") + assert call_count == 1, f"Expected 1 call, got {call_count}" + print(" ✓ PASSED: Business exceptions do not trigger retries") + + +async def test_4_config_switch_disabled(): + """Test 4: ENABLE_AUTO_RETRY=False disables retries""" + print("\n" + "=" * 60) + print("Test 4: ENABLE_AUTO_RETRY=False disables retries") + print("=" * 60) + + call_count = 0 + original_value = getattr(config, 'ENABLE_AUTO_RETRY', True) + + config.ENABLE_AUTO_RETRY = False + + try: + @auto_retry(max_retries=3, base_delay=0.1, enable_config=True) + async def test_func(): + nonlocal call_count + call_count += 1 + print(f" [Attempt {call_count}] Called, about to raise ReadTimeout...") + raise httpx.ReadTimeout("Simulated timeout error") + + try: + await test_func() + except httpx.ReadTimeout as e: + print(f" Expected ReadTimeout (no retry): {e}") + + print(f" Total calls: {call_count}") + assert call_count == 1, f"Expected 1 call (no retry), got {call_count}" + print(" ✓ PASSED: Config switch ENABLE_AUTO_RETRY=False disables retries") + finally: + config.ENABLE_AUTO_RETRY = original_value + + +async def test_5_config_switch_enabled(): + """Test 5: ENABLE_AUTO_RETRY=True enables retries""" + print("\n" + "=" * 60) + print("Test 5: ENABLE_AUTO_RETRY=True enables retries") + print("=" * 60) + + call_count = 0 + original_value = getattr(config, 'ENABLE_AUTO_RETRY', True) + + config.ENABLE_AUTO_RETRY = True + + try: + @auto_retry(max_retries=3, base_delay=0.1, enable_config=True) + async def test_func(): + nonlocal call_count + call_count += 1 + print(f" [Attempt {call_count}] Called...") + if call_count <= 2: + raise httpx.ReadTimeout("Simulated timeout error") + return "success" + + result = await test_func() + print(f" Result: {result}") + print(f" Total calls: {call_count}") + assert call_count == 3, f"Expected 3 calls, got {call_count}" + print(" ✓ PASSED: Config switch ENABLE_AUTO_RETRY=True enables retries") + finally: + config.ENABLE_AUTO_RETRY = original_value + + +def test_sync_function(): + """Test 6: Retry decorator works with sync functions""" + print("\n" + "=" * 60) + print("Test 6: Retry decorator works with sync functions") + print("=" * 60) + + call_count = 0 + + @auto_retry(max_retries=3, base_delay=0.01, enable_config=False) + def test_func(): + nonlocal call_count + call_count += 1 + print(f" [Attempt {call_count}] Called...") + if call_count <= 2: + raise httpx.ReadTimeout("Simulated timeout error") + return "sync success" + + result = test_func() + print(f" Result: {result}") + print(f" Total calls: {call_count}") + assert call_count == 3, f"Expected 3 calls, got {call_count}" + print(" ✓ PASSED: Retry decorator works with sync functions") + + +def test_exception_detection(): + """Test 7: Network exception detection""" + print("\n" + "=" * 60) + print("Test 7: Network exception detection") + print("=" * 60) + + network_exceptions = [ + httpx.ConnectError("test"), + httpx.ReadTimeout("test"), + httpx.ConnectTimeout("test"), + httpx.WriteTimeout("test"), + httpx.PoolTimeout("test"), + httpx.TimeoutException("test"), + httpx.NetworkError("test"), + httpx.TransportError("test"), + httpx.ProtocolError("test"), + ConnectionError("test"), + TimeoutError("test"), + ] + + non_network_exceptions = [ + DataFetchError("business error"), + ValueError("test"), + KeyError("test"), + TypeError("test"), + ] + + print(" Checking network exceptions:") + for exc in network_exceptions: + result = is_network_exception(exc) + print(f" {exc.__class__.__name__}: is_network_exception = {result}") + assert result, f"Expected {exc.__class__.__name__} to be detected as network exception" + + print(" Checking non-network exceptions:") + for exc in non_network_exceptions: + result = is_network_exception(exc) + print(f" {exc.__class__.__name__}: is_network_exception = {result}") + assert not result, f"Expected {exc.__class__.__name__} NOT to be detected as network exception" + + print(" ✓ PASSED: Network exception detection is correct") + + +async def main(): + print("=" * 60) + print("Auto Retry Decorator Verification") + print("=" * 60) + + test_exception_detection() + test_sync_function() + + await test_1_network_exception_retries() + await test_2_exponential_backoff() + await test_3_business_exception_no_retry() + await test_4_config_switch_disabled() + await test_5_config_switch_enabled() + + print("\n" + "=" * 60) + print("All tests PASSED!") + print("=" * 60) + print("\nSummary:") + print(" ✓ Network exceptions trigger retries (3 times)") + print(" ✓ Exponential backoff timing: 1s, 2s, 4s") + print(" ✓ Business exceptions (DataFetchError) do NOT trigger retries") + print(" ✓ Config switch ENABLE_AUTO_RETRY=False disables retries") + print(" ✓ Config switch ENABLE_AUTO_RETRY=True enables retries") + print(" ✓ Works with both sync and async functions") + + +if __name__ == "__main__": + asyncio.run(main())