diff --git a/src/marketdata/api_error.py b/src/marketdata/api_error.py index a8e4173..c93c55e 100644 --- a/src/marketdata/api_error.py +++ b/src/marketdata/api_error.py @@ -1,14 +1,12 @@ from functools import wraps +from logging import DEBUG from typing import TYPE_CHECKING, Callable +from tenacity import before_sleep_log + from marketdata.api_status import API_STATUS_DATA, APIStatusResult from marketdata.exceptions import RequestError -from marketdata.internal_settings import ( - MAX_RETRY_ATTEMPTS, - MAX_RETRY_BACKOFF, - MIN_RETRY_BACKOFF, - RETRY_BACKOFF, -) +from marketdata.internal_settings import INITIAL_RETRY_DELAY from marketdata.resources.base import BaseResource from marketdata.retry import get_retry_adapter @@ -25,25 +23,22 @@ def wrapper(*args, **kwargs): resource: BaseResource = args[0] client: "MarketDataClient" = resource.client logger = client.logger + log_before_sleep = before_sleep_log(logger, log_level=DEBUG) - try: - return func(*args, **kwargs) - except RequestError as e: - if API_STATUS_DATA.should_refresh: - API_STATUS_DATA.refresh(client) - + def _status_check_before_sleep(retry_state): status = API_STATUS_DATA.get_api_status(client, service) - if status in (APIStatusResult.ONLINE, APIStatusResult.UNKNOWN): - retry_adapter = get_retry_adapter( - attempts=MAX_RETRY_ATTEMPTS, - backoff=RETRY_BACKOFF, - exceptions=[RequestError], - logger=logger, - reraise=True, - min_backoff=MIN_RETRY_BACKOFF, - max_backoff=MAX_RETRY_BACKOFF, - ) - return retry_adapter(func, *args, **kwargs) - raise e + if status == APIStatusResult.OFFLINE: + raise retry_state.outcome.exception() + log_before_sleep(retry_state) + + retry_adapter = get_retry_adapter( + attempts=client.max_retries + 1, + initial_delay=INITIAL_RETRY_DELAY, + exceptions=[RequestError], + logger=logger, + reraise=True, + before_sleep=_status_check_before_sleep, + ) + return retry_adapter(func, *args, **kwargs) return wrapper diff --git a/src/marketdata/api_status.py b/src/marketdata/api_status.py index f803528..dcef744 100644 --- a/src/marketdata/api_status.py +++ b/src/marketdata/api_status.py @@ -1,10 +1,14 @@ import datetime import logging +import threading from enum import Enum from typing import TYPE_CHECKING from marketdata.exceptions import BadStatusCodeError, InvalidStatusDataError -from marketdata.internal_settings import REFRESH_API_STATUS_INTERVAL +from marketdata.internal_settings import ( + CACHE_VALIDITY_INTERVAL, + REFRESH_API_STATUS_INTERVAL, +) if TYPE_CHECKING: from marketdata.client import MarketDataClient @@ -18,54 +22,67 @@ class APIStatusResult(str, Enum): class APIStatusData: def __init__(self): + self._lock = threading.Lock() + self._refresh_in_flight = False + self._refresh_thread: threading.Thread | None = None + self._last_refresh_at: datetime.datetime | None = None self.service = [] self.status = [] self.online = [] - self.uptimePct30d = [] - self.uptimePct90d = [] - self.updated = [] def update(self, data: dict): try: - self.service = data["service"] - self.status = data["status"] - self.online = data["online"] - self.uptimePct30d = data["uptimePct30d"] - self.uptimePct90d = data["uptimePct90d"] - self.updated = data["updated"] + new_service = data["service"] + new_status = data["status"] + new_online = data["online"] except KeyError as e: raise InvalidStatusDataError(f"Invalid status data: {e}") from e + with self._lock: + self.service = new_service + self.status = new_status + self.online = new_online + self._last_refresh_at = datetime.datetime.now() @property - def last_updated(self) -> datetime.datetime: - if not self.updated: - return datetime.datetime(1970, 1, 1) - return datetime.datetime.fromtimestamp(min(self.updated)) + def cache_age(self) -> datetime.timedelta: + if self._last_refresh_at is None: + return datetime.timedelta.max + return datetime.datetime.now() - self._last_refresh_at @property def should_refresh(self) -> bool: - return datetime.datetime.now() - self.last_updated > REFRESH_API_STATUS_INTERVAL + return self.cache_age >= REFRESH_API_STATUS_INTERVAL + + @property + def is_cache_stale(self) -> bool: + return self.cache_age >= CACHE_VALIDITY_INTERVAL def get_api_status( self, client: "MarketDataClient", service: str ) -> APIStatusResult: client.logger.debug(f"Checking if service {service} is online") - if self.should_refresh and not self.refresh(client): - return APIStatusResult.UNKNOWN - if service not in self.service: - client.logger.error(f"Service {service} not found in API status") + if self.is_cache_stale: + self._trigger_async_refresh(client) return APIStatusResult.UNKNOWN - service_index = self.service.index(service) - if self.status[service_index] != APIStatusResult.ONLINE: - client.logger.error(f"Service {service} is offline") - return APIStatusResult.OFFLINE - if not self.online[service_index]: - client.logger.error(f"Service {service} is not online") - return APIStatusResult.OFFLINE - client.logger.debug(f"Service {service} is online") - return APIStatusResult.ONLINE + if self.should_refresh: + self._trigger_async_refresh(client) + + with self._lock: + if service not in self.service: + client.logger.error(f"Service {service} not found in API status") + return APIStatusResult.UNKNOWN + + service_index = self.service.index(service) + if self.status[service_index] != APIStatusResult.ONLINE: + client.logger.error(f"Service {service} is offline") + return APIStatusResult.OFFLINE + if not self.online[service_index]: + client.logger.error(f"Service {service} is not online") + return APIStatusResult.OFFLINE + client.logger.debug(f"Service {service} is online") + return APIStatusResult.ONLINE def refresh(self, client: "MarketDataClient") -> bool: try: @@ -86,5 +103,31 @@ def refresh(self, client: "MarketDataClient") -> bool: client.logger.error(f"Failed to refresh API status: {e}") return False + def _trigger_async_refresh(self, client: "MarketDataClient") -> None: + with self._lock: + if self._refresh_in_flight: + return + self._refresh_in_flight = True + + try: + thread = threading.Thread( + target=self._async_refresh, args=(client,), daemon=True + ) + self._refresh_thread = thread + thread.start() + except Exception: + with self._lock: + self._refresh_in_flight = False + raise + + def _async_refresh(self, client: "MarketDataClient") -> None: + try: + self.refresh(client) + except Exception: + client.logger.exception("Async status refresh failed") + finally: + with self._lock: + self._refresh_in_flight = False + API_STATUS_DATA = APIStatusData() diff --git a/src/marketdata/client.py b/src/marketdata/client.py index ac71f05..24cb9b5 100644 --- a/src/marketdata/client.py +++ b/src/marketdata/client.py @@ -8,6 +8,7 @@ from marketdata.input_types.base import UserUniversalAPIParams from marketdata.internal_settings import ( HTTP_TIMEOUT, + MAX_RETRY_ATTEMPTS, NO_TOKEN_VALUE, RETRY_STATUS_CODES, ) @@ -23,8 +24,16 @@ class MarketDataClient: - def __init__(self, token: str = None, logger: Logger = None): + def __init__( + self, + token: str = None, + logger: Logger = None, + max_retries: int = MAX_RETRY_ATTEMPTS, + ): + if max_retries < 0: + raise ValueError("max_retries must be >= 0") self.token = token or settings.marketdata_token + self.max_retries = max_retries self.library_version = version("marketdata-sdk-py") self.library_user_agent = self._get_user_agent() diff --git a/src/marketdata/internal_settings.py b/src/marketdata/internal_settings.py index 03c558d..581cf58 100644 --- a/src/marketdata/internal_settings.py +++ b/src/marketdata/internal_settings.py @@ -7,14 +7,13 @@ class NoTokenValueType: MAX_CONCURRENT_REQUESTS = 50 MAX_RETRY_ATTEMPTS = 3 -RETRY_BACKOFF = 0.5 +INITIAL_RETRY_DELAY = 1.0 RETRY_STATUS_CODES = lambda x: x > 500 HTTP_TIMEOUT = 60 -MIN_RETRY_BACKOFF = 0.5 -MAX_RETRY_BACKOFF = 5 VALID_STATUS_CODES = [200, 203] GLOBAL_EXCLUDED_PARAMS = ["output_format", "filename"] REFRESH_API_STATUS_INTERVAL = datetime.timedelta(minutes=4, seconds=30) +CACHE_VALIDITY_INTERVAL = datetime.timedelta(minutes=5) ALLOWED_POSITIONAL_PARAMS = ["symbol", "symbols", "lookup"] DATAFRAME_HANDLERS_PRIORITY = ["pandas", "polars"] NO_TOKEN_VALUE = NoTokenValueType() diff --git a/src/marketdata/retry.py b/src/marketdata/retry.py index 6d5b53a..316b991 100644 --- a/src/marketdata/retry.py +++ b/src/marketdata/retry.py @@ -1,31 +1,66 @@ +import math +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime from logging import DEBUG, Logger +from typing import Callable from tenacity import ( Retrying, before_sleep_log, retry_if_exception_type, stop_after_attempt, - wait_exponential, ) +def parse_retry_after(value: str | None) -> float | None: + if not value: + return None + value = value.strip() + try: + parsed = float(value) + except ValueError: + parsed = None + if parsed is not None: + if not math.isfinite(parsed): + return None + return max(0.0, parsed) + try: + target = parsedate_to_datetime(value) + except (TypeError, ValueError): + return None + if target.tzinfo is None: + target = target.replace(tzinfo=timezone.utc) + return max(0.0, (target - datetime.now(timezone.utc)).total_seconds()) + + def get_retry_adapter( attempts: int, - backoff: float, + initial_delay: float, logger: Logger, exceptions: list[Exception] = None, reraise: bool = False, - min_backoff: float = 0.5, - max_backoff: float = 5, + before_sleep: Callable = None, ) -> Retrying: if not exceptions: exceptions = [Exception] + def _compute_wait(retry_state) -> float: + exc = retry_state.outcome.exception() if retry_state.outcome else None + if exc is not None: + response = getattr(exc, "response", None) + if response is not None: + retry_after = parse_retry_after( + response.headers.get("Retry-After") + ) + if retry_after is not None: + return retry_after + return initial_delay * 2 ** (retry_state.attempt_number - 1) + return Retrying( stop=stop_after_attempt(attempts), - wait=wait_exponential(multiplier=backoff, min=min_backoff, max=max_backoff), - retry=retry_if_exception_type(*exceptions), + wait=_compute_wait, + retry=retry_if_exception_type(tuple(exceptions)), reraise=reraise, - before_sleep=before_sleep_log(logger, log_level=DEBUG), + before_sleep=before_sleep or before_sleep_log(logger, log_level=DEBUG), ) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 22acd10..a61624f 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -4,6 +4,7 @@ import pytest +from marketdata.api_status import API_STATUS_DATA from marketdata.client import MarketDataClient from marketdata.types import UserRateLimits @@ -25,6 +26,15 @@ def chdir(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch): monkeypatch.chdir(tmp_path) +@pytest.fixture(autouse=True) +def _reset_api_status_data(): + API_STATUS_DATA.__init__() + yield + if API_STATUS_DATA._refresh_thread is not None: + API_STATUS_DATA._refresh_thread.join(timeout=2) + API_STATUS_DATA.__init__() + + @pytest.fixture def client(respx_mock): diff --git a/src/tests/test_api_error.py b/src/tests/test_api_error.py index 25e529f..482a34a 100644 --- a/src/tests/test_api_error.py +++ b/src/tests/test_api_error.py @@ -4,65 +4,74 @@ from httpx import Request, Response from marketdata.api_error import api_error_handler +from marketdata.api_status import APIStatusResult from marketdata.exceptions import RequestError from marketdata.resources.base import BaseResource -from src.marketdata.api_status import APIStatusResult class DummyResource(BaseResource): + call_count = 0 + @api_error_handler def test_function_fails(self): + DummyResource.call_count += 1 request = Request(method="GET", url="https://example.com") response = Response(status_code=500) raise RequestError("test exception", request=request, response=response) +@pytest.fixture(autouse=True) +def _reset_dummy(): + DummyResource.call_count = 0 + yield + DummyResource.call_count = 0 + + +@pytest.fixture(autouse=True) +def _no_sleep(monkeypatch): + monkeypatch.setattr("time.sleep", lambda *_: None) + + @patch( "marketdata.api_error.API_STATUS_DATA.get_api_status", return_value=APIStatusResult.OFFLINE, ) -@patch( - "marketdata.api_error.get_retry_adapter", - return_value=lambda x, *args, **kwargs: x(*args, **kwargs), -) -def test_api_error_handler_fails_when_api_is_offline( - retry_adapter, api_status_data, client -): +def test_api_error_handler_offline_aborts_after_first_failure(_, client): resource = DummyResource(client=client) with pytest.raises(RequestError): resource.test_function_fails() - retry_adapter.assert_not_called() + assert DummyResource.call_count == 1 @patch( "marketdata.api_error.API_STATUS_DATA.get_api_status", return_value=APIStatusResult.ONLINE, ) -@patch( - "marketdata.api_error.get_retry_adapter", - return_value=lambda x, *args, **kwargs: x(*args, **kwargs), -) -def test_api_error_handler_fails_when_api_is_online( - retry_adapter, api_status_data, client -): +def test_api_error_handler_online_retries_max_attempts(_, client): resource = DummyResource(client=client) with pytest.raises(RequestError): resource.test_function_fails() - retry_adapter.assert_called_once() + assert DummyResource.call_count == 4 @patch( "marketdata.api_error.API_STATUS_DATA.get_api_status", return_value=APIStatusResult.UNKNOWN, ) +def test_api_error_handler_unknown_retries_max_attempts(_, client): + resource = DummyResource(client=client) + with pytest.raises(RequestError): + resource.test_function_fails() + assert DummyResource.call_count == 4 + + @patch( - "marketdata.api_error.get_retry_adapter", - return_value=lambda x, *args, **kwargs: x(*args, **kwargs), + "marketdata.api_error.API_STATUS_DATA.get_api_status", + return_value=APIStatusResult.ONLINE, ) -def test_api_error_handler_fails_when_api_is_unknown( - retry_adapter, api_status_data, client -): +def test_api_error_handler_respects_max_retries_zero(_, client): + client.max_retries = 0 resource = DummyResource(client=client) with pytest.raises(RequestError): resource.test_function_fails() - retry_adapter.assert_called_once() + assert DummyResource.call_count == 1 diff --git a/src/tests/test_api_status.py b/src/tests/test_api_status.py index eeb5272..cfd1095 100644 --- a/src/tests/test_api_status.py +++ b/src/tests/test_api_status.py @@ -1,156 +1,64 @@ +import datetime +import time +from unittest.mock import MagicMock + +import pytest + from marketdata.api_status import API_STATUS_DATA, APIStatusResult +from marketdata.exceptions import InvalidStatusDataError +from marketdata.internal_settings import ( + CACHE_VALIDITY_INTERVAL, + REFRESH_API_STATUS_INTERVAL, +) + +ALL_SERVICES = [ + "/v1/markets/status/", + "/v1/options/chain/", + "/v1/options/expirations/", + "/v1/options/lookup/", + "/v1/options/quotes/", + "/v1/options/strikes/", + "/v1/stocks/bulkcandles/", + "/v1/stocks/bulkquotes/", + "/v1/stocks/candles/", + "/v1/stocks/earnings/", + "/v1/stocks/news/", + "/v1/stocks/quotes/", +] def test_api_status_data(load_json, respx_mock, client): mock_data = load_json("api_status_response_200") - respx_mock.get("https://api.marketdata.app/status/").respond( - json=mock_data, - status_code=200, + json=mock_data, status_code=200 ) API_STATUS_DATA.refresh(client) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/markets/status/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/options/chain/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/options/expirations/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/options/lookup/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/options/quotes/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/options/strikes/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/bulkcandles/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/bulkquotes/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/candles/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/earnings/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/news/") - == APIStatusResult.ONLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/quotes/") - == APIStatusResult.ONLINE - ) + for service in ALL_SERVICES: + assert ( + API_STATUS_DATA.get_api_status(client, service) + == APIStatusResult.ONLINE + ) def test_api_status_data_offline(load_json, respx_mock, client): mock_data = load_json("api_status_response_200") - - mock_data["status"] = [ - "offline", - "offline", - "offline", - "offline", - "offline", - "offline", - "offline", - "offline", - "offline", - "offline", - "offline", - "offline", - ] - mock_data["online"] = [ - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - ] + mock_data["status"] = ["offline"] * len(mock_data["service"]) + mock_data["online"] = [False] * len(mock_data["service"]) respx_mock.get("https://api.marketdata.app/status/").respond( - json=mock_data, - status_code=200, + json=mock_data, status_code=200 ) API_STATUS_DATA.refresh(client) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/markets/status/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/options/chain/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/options/expirations/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/options/lookup/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/options/quotes/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/options/strikes/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/bulkcandles/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/bulkquotes/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/candles/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/earnings/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/news/") - == APIStatusResult.OFFLINE - ) - assert ( - API_STATUS_DATA.get_api_status(client, "/v1/stocks/quotes/") - == APIStatusResult.OFFLINE - ) + for service in ALL_SERVICES: + assert ( + API_STATUS_DATA.get_api_status(client, service) + == APIStatusResult.OFFLINE + ) def test_api_status_data_unknown(respx_mock, client): - respx_mock.get("https://api.marketdata.app/status/").respond( - status_code=500, - ) + respx_mock.get("https://api.marketdata.app/status/").respond(status_code=500) API_STATUS_DATA.refresh(client) assert ( @@ -172,8 +80,11 @@ def test_api_status_data_service_not_online(respx_mock, client): }, status_code=200, ) - status = API_STATUS_DATA.get_api_status(client, "/v1/markets/status/") - assert status == APIStatusResult.OFFLINE + API_STATUS_DATA.refresh(client) + assert ( + API_STATUS_DATA.get_api_status(client, "/v1/markets/status/") + == APIStatusResult.OFFLINE + ) def test_api_status_data_service_not_found(respx_mock, client): @@ -183,6 +94,9 @@ def test_api_status_data_service_not_found(respx_mock, client): "service": ["/v1/markets/status/"], "status": ["online"], "online": [True], + "uptimePct30d": [100], + "uptimePct90d": [100], + "updated": [int(time.time())], }, status_code=200, ) @@ -192,3 +106,157 @@ def test_api_status_data_service_not_found(respx_mock, client): API_STATUS_DATA.get_api_status(client, "invalid_service") == APIStatusResult.UNKNOWN ) + + +def _populate_cache_with_age(age: datetime.timedelta): + API_STATUS_DATA.service = ["/v1/markets/status/"] + API_STATUS_DATA.status = ["online"] + API_STATUS_DATA.online = [True] + API_STATUS_DATA._last_refresh_at = datetime.datetime.now() - age + + +def test_get_api_status_fresh_cache_uses_cached_no_refresh(client, monkeypatch): + _populate_cache_with_age(datetime.timedelta(seconds=10)) + triggered = [] + monkeypatch.setattr( + API_STATUS_DATA, "_trigger_async_refresh", lambda c: triggered.append(c) + ) + + status = API_STATUS_DATA.get_api_status(client, "/v1/markets/status/") + assert status == APIStatusResult.ONLINE + assert triggered == [] + + +def test_get_api_status_in_refresh_zone_uses_cache_and_triggers_refresh( + client, monkeypatch +): + age = REFRESH_API_STATUS_INTERVAL + datetime.timedelta(seconds=5) + assert age < CACHE_VALIDITY_INTERVAL + _populate_cache_with_age(age) + triggered = [] + monkeypatch.setattr( + API_STATUS_DATA, "_trigger_async_refresh", lambda c: triggered.append(c) + ) + + status = API_STATUS_DATA.get_api_status(client, "/v1/markets/status/") + assert status == APIStatusResult.ONLINE + assert triggered == [client] + + +def test_get_api_status_stale_cache_returns_unknown_and_triggers_refresh( + client, monkeypatch +): + _populate_cache_with_age(CACHE_VALIDITY_INTERVAL + datetime.timedelta(seconds=1)) + triggered = [] + monkeypatch.setattr( + API_STATUS_DATA, "_trigger_async_refresh", lambda c: triggered.append(c) + ) + + status = API_STATUS_DATA.get_api_status(client, "/v1/markets/status/") + assert status == APIStatusResult.UNKNOWN + assert triggered == [client] + + +def test_get_api_status_empty_cache_returns_unknown_and_triggers_refresh( + client, monkeypatch +): + triggered = [] + monkeypatch.setattr( + API_STATUS_DATA, "_trigger_async_refresh", lambda c: triggered.append(c) + ) + + status = API_STATUS_DATA.get_api_status(client, "/v1/markets/status/") + assert status == APIStatusResult.UNKNOWN + assert triggered == [client] + + +def test_trigger_async_refresh_skips_when_in_flight(client): + API_STATUS_DATA._refresh_in_flight = True + spawned = [] + + def fake_thread(*args, **kwargs): + spawned.append(kwargs) + m = MagicMock() + return m + + import marketdata.api_status as mod + original_thread = mod.threading.Thread + mod.threading.Thread = fake_thread + try: + API_STATUS_DATA._trigger_async_refresh(client) + assert spawned == [] + finally: + mod.threading.Thread = original_thread + + +def test_trigger_async_refresh_runs_in_background(respx_mock, client): + respx_mock.get("https://api.marketdata.app/status/").respond( + json={ + "service": ["/v1/markets/status/"], + "status": ["online"], + "online": [True], + "uptimePct30d": [100], + "uptimePct90d": [100], + "updated": [int(time.time())], + }, + status_code=200, + ) + + API_STATUS_DATA._trigger_async_refresh(client) + thread = API_STATUS_DATA._refresh_thread + assert thread is not None + thread.join(timeout=5) + assert not thread.is_alive() + assert API_STATUS_DATA._refresh_in_flight is False + assert API_STATUS_DATA.service == ["/v1/markets/status/"] + + +@pytest.mark.parametrize("missing_key", ["service", "status", "online"]) +def test_update_raises_when_required_key_missing(missing_key): + data = {"service": ["x"], "status": ["online"], "online": [True]} + data.pop(missing_key) + with pytest.raises(InvalidStatusDataError, match=missing_key): + API_STATUS_DATA.update(data) + + +def test_async_refresh_clears_in_flight_on_failure(respx_mock, client): + respx_mock.get("https://api.marketdata.app/status/").respond(status_code=500) + + API_STATUS_DATA._trigger_async_refresh(client) + thread = API_STATUS_DATA._refresh_thread + thread.join(timeout=5) + assert API_STATUS_DATA._refresh_in_flight is False + + +def test_async_refresh_logs_unexpected_exception(client, monkeypatch): + def boom(_): + raise RuntimeError("network exploded") + + monkeypatch.setattr(API_STATUS_DATA, "refresh", boom) + logged = [] + monkeypatch.setattr( + client.logger, "exception", lambda msg, *a, **kw: logged.append(msg) + ) + + API_STATUS_DATA._trigger_async_refresh(client) + thread = API_STATUS_DATA._refresh_thread + thread.join(timeout=5) + + assert logged == ["Async status refresh failed"] + assert API_STATUS_DATA._refresh_in_flight is False + + +def test_trigger_async_refresh_clears_flag_when_thread_construction_fails( + client, monkeypatch +): + def boom(*args, **kwargs): + raise RuntimeError("cannot spawn") + + import marketdata.api_status as mod + + monkeypatch.setattr(mod.threading, "Thread", boom) + + with pytest.raises(RuntimeError): + API_STATUS_DATA._trigger_async_refresh(client) + + assert API_STATUS_DATA._refresh_in_flight is False diff --git a/src/tests/test_client.py b/src/tests/test_client.py index ec4223d..ee5b4f4 100644 --- a/src/tests/test_client.py +++ b/src/tests/test_client.py @@ -14,29 +14,12 @@ ) from marketdata.input_types.base import OutputFormat from marketdata.internal_settings import NO_TOKEN_VALUE -from marketdata.retry import get_retry_adapter from marketdata.sdk_error import MarketDataClientErrorResult from marketdata.settings import MarketDataSettings, settings from marketdata.types import UserRateLimits from marketdata.utils import format_duration_log -def test_get_retry_adapter(client): - retry_adapter = get_retry_adapter( - attempts=3, - backoff=0.5, - exceptions=[], - logger=client.logger, - ) - assert retry_adapter is not None - assert retry_adapter.stop.max_attempt_number == 3 - assert retry_adapter.wait.multiplier == 0.5 - assert retry_adapter.retry.exception_types == Exception - assert retry_adapter.reraise == False - assert retry_adapter.wait.min == 0.5 - assert retry_adapter.wait.max == 5 - - def test_user_rate_limits_str(): user_rate_limits = UserRateLimits( requests_limit=100, @@ -71,7 +54,16 @@ def test_client_headers_no_token(respx_mock): } -def test_client_make_request_retry(client, respx_mock): +def test_client_make_request_retry(client, respx_mock, monkeypatch): + monkeypatch.setattr("time.sleep", lambda *_: None) + from marketdata.api_status import API_STATUS_DATA + + monkeypatch.setattr( + API_STATUS_DATA, + "_trigger_async_refresh", + lambda c: API_STATUS_DATA.refresh(c), + ) + respx_mock.get("https://api.marketdata.app/v1/stocks/prices/").respond( json={}, status_code=502, @@ -80,22 +72,16 @@ def test_client_make_request_retry(client, respx_mock): result = client.stocks.prices(symbols="AAPL") assert isinstance(result, MarketDataClientErrorResult) + prices_calls = [ + c for c in respx_mock.calls if c.request.url.path == "/v1/stocks/prices/" + ] + status_calls = [ + c for c in respx_mock.calls if c.request.url.path == "/status/" + ] + assert len(prices_calls) == 4 + assert len(status_calls) == 1 assert respx_mock.calls.call_count == 6 - # 1st request is for user rate limits - assert respx_mock.calls[0].request.url.path == "/user/" - - # 2nd request is stocks.prices (and it fails with 502 status code) - assert respx_mock.calls[1].request.url.path == "/v1/stocks/prices/" - - # 3rd request is API status check - assert respx_mock.calls[2].request.url.path == "/status/" - - # 4th, 5th, 6th requests are retries - assert respx_mock.calls[3].request.url.path == "/v1/stocks/prices/" - assert respx_mock.calls[4].request.url.path == "/v1/stocks/prices/" - assert respx_mock.calls[5].request.url.path == "/v1/stocks/prices/" - def test_client_make_request_bad_status_not_retry(client, respx_mock): respx_mock.get("https://api.marketdata.app/v1/stocks/prices/").respond( @@ -319,6 +305,107 @@ def test_client_pre_and_post_request_logs(client, respx_mock): ) +def test_client_max_retries_default(client): + assert client.max_retries == 3 + + +def test_client_max_retries_custom(): + with patch.object(MarketDataClient, "_setup_rate_limits"): + c = MarketDataClient(token="test", max_retries=5) + assert c.max_retries == 5 + + +def test_client_max_retries_zero(): + with patch.object(MarketDataClient, "_setup_rate_limits"): + c = MarketDataClient(token="test", max_retries=0) + assert c.max_retries == 0 + + +def test_client_max_retries_negative_raises(): + with pytest.raises(ValueError): + MarketDataClient(token="test", max_retries=-1) + + +def test_client_max_retries_zero_no_retry(respx_mock, monkeypatch): + monkeypatch.setattr("time.sleep", lambda *_: None) + headers = { + "x-api-ratelimit-limit": "100", + "x-api-ratelimit-remaining": "99", + "x-api-ratelimit-reset": "60", + "x-api-ratelimit-consumed": "1", + } + respx_mock.get("https://api.marketdata.app/user/").respond( + json={}, headers=headers, status_code=200 + ) + respx_mock.get("https://api.marketdata.app/v1/stocks/prices/").respond( + json={}, status_code=502 + ) + + c = MarketDataClient(token="test", max_retries=0) + setattr( + c, + "_extract_rate_limits", + lambda x: UserRateLimits( + requests_limit=100, + requests_remaining=99, + requests_reset=60, + requests_consumed=1, + ), + ) + + result = c.stocks.prices(symbols="AAPL") + assert isinstance(result, MarketDataClientErrorResult) + assert respx_mock.calls.call_count == 2 + + +def test_client_max_retries_one(respx_mock, monkeypatch): + monkeypatch.setattr("time.sleep", lambda *_: None) + headers = { + "x-api-ratelimit-limit": "100", + "x-api-ratelimit-remaining": "99", + "x-api-ratelimit-reset": "60", + "x-api-ratelimit-consumed": "1", + } + respx_mock.get("https://api.marketdata.app/user/").respond( + json={}, headers=headers, status_code=200 + ) + import time as _time + + _now = _time.time() + respx_mock.get("https://api.marketdata.app/status/").respond( + json={ + "service": ["/v1/stocks/bulkquotes/"], + "status": ["online"], + "online": [True], + "uptimePct30d": [100], + "uptimePct90d": [100], + "updated": [_now], + }, + headers=headers, + status_code=200, + ) + respx_mock.get("https://api.marketdata.app/v1/stocks/prices/").respond( + json={}, status_code=502 + ) + + c = MarketDataClient(token="test", max_retries=1) + setattr( + c, + "_extract_rate_limits", + lambda x: UserRateLimits( + requests_limit=100, + requests_remaining=99, + requests_reset=60, + requests_consumed=1, + ), + ) + + result = c.stocks.prices(symbols="AAPL") + assert isinstance(result, MarketDataClientErrorResult) + prices_calls = [c for c in respx_mock.calls if c.request.url.path == "/v1/stocks/prices/"] + assert len(prices_calls) == 2 + + def test_settings_extra_env_vars(): with patch.dict( os.environ, {"RANDOM_VAR_FOR_TESTING": "123", "MARKETDATA_TOKEN": "test_token"} diff --git a/src/tests/test_retry.py b/src/tests/test_retry.py new file mode 100644 index 0000000..46785de --- /dev/null +++ b/src/tests/test_retry.py @@ -0,0 +1,145 @@ +import datetime +from unittest.mock import MagicMock + +import pytest +from httpx import Headers, Request, Response + +from marketdata.exceptions import RequestError +from marketdata.retry import get_retry_adapter, parse_retry_after + + +def _make_retry_state(attempt_number: int, exc: Exception | None): + outcome = None + if exc is not None: + outcome = MagicMock() + outcome.exception.return_value = exc + return type( + "S", (), {"attempt_number": attempt_number, "outcome": outcome} + )() + + +def _retry_after_response(value: str) -> Response: + request = Request(method="GET", url="https://example.com") + return Response( + status_code=503, request=request, headers=Headers({"Retry-After": value}) + ) + + +def test_get_retry_adapter(client): + retry_adapter = get_retry_adapter( + attempts=4, + initial_delay=1.0, + exceptions=[], + logger=client.logger, + ) + assert retry_adapter is not None + assert retry_adapter.stop.max_attempt_number == 4 + assert retry_adapter.retry.exception_types == (Exception,) + assert retry_adapter.reraise == False + + state = _make_retry_state(attempt_number=1, exc=None) + assert retry_adapter.wait(state) == 1.0 + state.attempt_number = 2 + assert retry_adapter.wait(state) == 2.0 + state.attempt_number = 3 + assert retry_adapter.wait(state) == 4.0 + + +@pytest.mark.parametrize( + "value,expected", + [ + (None, None), + ("", None), + (" ", None), + ("0", 0.0), + ("120", 120.0), + ("3.5", 3.5), + ("-5", 0.0), + ("not-a-date", None), + ("nan", None), + ("inf", None), + ("-inf", None), + ], +) +def test_parse_retry_after_seconds(value, expected): + assert parse_retry_after(value) == expected + + +def test_parse_retry_after_http_date_future(): + future = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=60) + header = future.strftime("%a, %d %b %Y %H:%M:%S GMT") + result = parse_retry_after(header) + assert result is not None + assert 55 < result <= 60 + + +def test_parse_retry_after_http_date_past(): + assert parse_retry_after("Wed, 21 Oct 1995 07:28:00 GMT") == 0.0 + + +def test_parse_retry_after_naive_asctime_treated_as_utc(): + future = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=120) + header = future.strftime("%a %b %d %H:%M:%S %Y") + result = parse_retry_after(header) + assert result is not None + assert 110 < result <= 120 + + +def test_compute_wait_retry_after_overrides_exponential(client): + retry_adapter = get_retry_adapter( + attempts=4, + initial_delay=1.0, + exceptions=[RequestError], + logger=client.logger, + ) + exc = RequestError( + "boom", + request=Request(method="GET", url="https://example.com"), + response=_retry_after_response("7"), + ) + state = _make_retry_state(attempt_number=2, exc=exc) + assert retry_adapter.wait(state) == 7.0 + + +def test_compute_wait_no_retry_after_falls_back_to_exponential(client): + retry_adapter = get_retry_adapter( + attempts=4, + initial_delay=1.0, + exceptions=[RequestError], + logger=client.logger, + ) + request = Request(method="GET", url="https://example.com") + exc = RequestError( + "boom", + request=request, + response=Response(status_code=503, request=request), + ) + state = _make_retry_state(attempt_number=3, exc=exc) + assert retry_adapter.wait(state) == 4.0 + + +def test_compute_wait_invalid_retry_after_falls_back(client): + retry_adapter = get_retry_adapter( + attempts=4, + initial_delay=1.0, + exceptions=[RequestError], + logger=client.logger, + ) + exc = RequestError( + "boom", + request=Request(method="GET", url="https://example.com"), + response=_retry_after_response("garbage"), + ) + state = _make_retry_state(attempt_number=1, exc=exc) + assert retry_adapter.wait(state) == 1.0 + + +def test_compute_wait_exception_without_response_falls_back(client): + retry_adapter = get_retry_adapter( + attempts=4, + initial_delay=1.0, + exceptions=[Exception], + logger=client.logger, + ) + state = _make_retry_state(attempt_number=2, exc=Exception("network down")) + assert retry_adapter.wait(state) == 2.0 diff --git a/uv.lock b/uv.lock index 42a730b..c2ec708 100644 --- a/uv.lock +++ b/uv.lock @@ -338,7 +338,7 @@ wheels = [ [[package]] name = "marketdata-sdk-py" -version = "1.1.0" +version = "1.2.0" source = { editable = "." } dependencies = [ { name = "httpx" },