Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 19 additions & 24 deletions src/marketdata/api_error.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from functools import wraps
from logging import DEBUG
from typing import TYPE_CHECKING, Callable

from tenacity import before_sleep_log

from marketdata.api_status import API_STATUS_DATA, APIStatusResult
from marketdata.exceptions import RequestError
from marketdata.internal_settings import (
MAX_RETRY_ATTEMPTS,
MAX_RETRY_BACKOFF,
MIN_RETRY_BACKOFF,
RETRY_BACKOFF,
)
from marketdata.internal_settings import INITIAL_RETRY_DELAY
from marketdata.resources.base import BaseResource
from marketdata.retry import get_retry_adapter

Expand All @@ -25,25 +23,22 @@ def wrapper(*args, **kwargs):
resource: BaseResource = args[0]
client: "MarketDataClient" = resource.client
logger = client.logger
log_before_sleep = before_sleep_log(logger, log_level=DEBUG)

try:
return func(*args, **kwargs)
except RequestError as e:
if API_STATUS_DATA.should_refresh:
API_STATUS_DATA.refresh(client)

def _status_check_before_sleep(retry_state):
status = API_STATUS_DATA.get_api_status(client, service)
if status in (APIStatusResult.ONLINE, APIStatusResult.UNKNOWN):
retry_adapter = get_retry_adapter(
attempts=MAX_RETRY_ATTEMPTS,
backoff=RETRY_BACKOFF,
exceptions=[RequestError],
logger=logger,
reraise=True,
min_backoff=MIN_RETRY_BACKOFF,
max_backoff=MAX_RETRY_BACKOFF,
)
return retry_adapter(func, *args, **kwargs)
raise e
if status == APIStatusResult.OFFLINE:
raise retry_state.outcome.exception()
log_before_sleep(retry_state)

retry_adapter = get_retry_adapter(
attempts=client.max_retries + 1,
initial_delay=INITIAL_RETRY_DELAY,
exceptions=[RequestError],
logger=logger,
reraise=True,
before_sleep=_status_check_before_sleep,
)
return retry_adapter(func, *args, **kwargs)

return wrapper
99 changes: 71 additions & 28 deletions src/marketdata/api_status.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import datetime
import logging
import threading
from enum import Enum
from typing import TYPE_CHECKING

from marketdata.exceptions import BadStatusCodeError, InvalidStatusDataError
from marketdata.internal_settings import REFRESH_API_STATUS_INTERVAL
from marketdata.internal_settings import (
CACHE_VALIDITY_INTERVAL,
REFRESH_API_STATUS_INTERVAL,
)

if TYPE_CHECKING:
from marketdata.client import MarketDataClient
Expand All @@ -18,54 +22,67 @@ class APIStatusResult(str, Enum):

class APIStatusData:
def __init__(self):
self._lock = threading.Lock()
self._refresh_in_flight = False
self._refresh_thread: threading.Thread | None = None
self._last_refresh_at: datetime.datetime | None = None
self.service = []
self.status = []
self.online = []
self.uptimePct30d = []
self.uptimePct90d = []
self.updated = []

def update(self, data: dict):
try:
self.service = data["service"]
self.status = data["status"]
self.online = data["online"]
self.uptimePct30d = data["uptimePct30d"]
self.uptimePct90d = data["uptimePct90d"]
self.updated = data["updated"]
new_service = data["service"]
new_status = data["status"]
new_online = data["online"]
except KeyError as e:
raise InvalidStatusDataError(f"Invalid status data: {e}") from e
with self._lock:
self.service = new_service
self.status = new_status
self.online = new_online
self._last_refresh_at = datetime.datetime.now()

@property
def last_updated(self) -> datetime.datetime:
if not self.updated:
return datetime.datetime(1970, 1, 1)
return datetime.datetime.fromtimestamp(min(self.updated))
def cache_age(self) -> datetime.timedelta:
if self._last_refresh_at is None:
return datetime.timedelta.max
return datetime.datetime.now() - self._last_refresh_at

@property
def should_refresh(self) -> bool:
return datetime.datetime.now() - self.last_updated > REFRESH_API_STATUS_INTERVAL
return self.cache_age >= REFRESH_API_STATUS_INTERVAL

@property
def is_cache_stale(self) -> bool:
return self.cache_age >= CACHE_VALIDITY_INTERVAL

def get_api_status(
self, client: "MarketDataClient", service: str
) -> APIStatusResult:
client.logger.debug(f"Checking if service {service} is online")
if self.should_refresh and not self.refresh(client):
return APIStatusResult.UNKNOWN

if service not in self.service:
client.logger.error(f"Service {service} not found in API status")
if self.is_cache_stale:
self._trigger_async_refresh(client)
return APIStatusResult.UNKNOWN

service_index = self.service.index(service)
if self.status[service_index] != APIStatusResult.ONLINE:
client.logger.error(f"Service {service} is offline")
return APIStatusResult.OFFLINE
if not self.online[service_index]:
client.logger.error(f"Service {service} is not online")
return APIStatusResult.OFFLINE
client.logger.debug(f"Service {service} is online")
return APIStatusResult.ONLINE
if self.should_refresh:
self._trigger_async_refresh(client)

with self._lock:
if service not in self.service:
client.logger.error(f"Service {service} not found in API status")
return APIStatusResult.UNKNOWN

service_index = self.service.index(service)
if self.status[service_index] != APIStatusResult.ONLINE:
client.logger.error(f"Service {service} is offline")
return APIStatusResult.OFFLINE
if not self.online[service_index]:
client.logger.error(f"Service {service} is not online")
return APIStatusResult.OFFLINE
client.logger.debug(f"Service {service} is online")
return APIStatusResult.ONLINE

def refresh(self, client: "MarketDataClient") -> bool:
try:
Expand All @@ -86,5 +103,31 @@ def refresh(self, client: "MarketDataClient") -> bool:
client.logger.error(f"Failed to refresh API status: {e}")
return False

def _trigger_async_refresh(self, client: "MarketDataClient") -> None:
with self._lock:
if self._refresh_in_flight:
return
self._refresh_in_flight = True

try:
thread = threading.Thread(
target=self._async_refresh, args=(client,), daemon=True
)
self._refresh_thread = thread
thread.start()
except Exception:
with self._lock:
self._refresh_in_flight = False
raise

def _async_refresh(self, client: "MarketDataClient") -> None:
try:
self.refresh(client)
except Exception:
client.logger.exception("Async status refresh failed")
finally:
with self._lock:
self._refresh_in_flight = False


API_STATUS_DATA = APIStatusData()
11 changes: 10 additions & 1 deletion src/marketdata/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from marketdata.input_types.base import UserUniversalAPIParams
from marketdata.internal_settings import (
HTTP_TIMEOUT,
MAX_RETRY_ATTEMPTS,
NO_TOKEN_VALUE,
RETRY_STATUS_CODES,
)
Expand All @@ -23,8 +24,16 @@

class MarketDataClient:

def __init__(self, token: str = None, logger: Logger = None):
def __init__(
self,
token: str = None,
logger: Logger = None,
max_retries: int = MAX_RETRY_ATTEMPTS,
):
if max_retries < 0:
raise ValueError("max_retries must be >= 0")
self.token = token or settings.marketdata_token
self.max_retries = max_retries
self.library_version = version("marketdata-sdk-py")
self.library_user_agent = self._get_user_agent()

Expand Down
5 changes: 2 additions & 3 deletions src/marketdata/internal_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ class NoTokenValueType:

MAX_CONCURRENT_REQUESTS = 50
MAX_RETRY_ATTEMPTS = 3
RETRY_BACKOFF = 0.5
INITIAL_RETRY_DELAY = 1.0
RETRY_STATUS_CODES = lambda x: x > 500
HTTP_TIMEOUT = 60
MIN_RETRY_BACKOFF = 0.5
MAX_RETRY_BACKOFF = 5
VALID_STATUS_CODES = [200, 203]
GLOBAL_EXCLUDED_PARAMS = ["output_format", "filename"]
REFRESH_API_STATUS_INTERVAL = datetime.timedelta(minutes=4, seconds=30)
CACHE_VALIDITY_INTERVAL = datetime.timedelta(minutes=5)
ALLOWED_POSITIONAL_PARAMS = ["symbol", "symbols", "lookup"]
DATAFRAME_HANDLERS_PRIORITY = ["pandas", "polars"]
NO_TOKEN_VALUE = NoTokenValueType()
49 changes: 42 additions & 7 deletions src/marketdata/retry.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,66 @@
import math
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from logging import DEBUG, Logger
from typing import Callable

from tenacity import (
Retrying,
before_sleep_log,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)


def parse_retry_after(value: str | None) -> float | None:
if not value:
return None
value = value.strip()
try:
parsed = float(value)
except ValueError:
parsed = None
if parsed is not None:
if not math.isfinite(parsed):
return None
return max(0.0, parsed)
try:
target = parsedate_to_datetime(value)
except (TypeError, ValueError):
return None
if target.tzinfo is None:
target = target.replace(tzinfo=timezone.utc)
return max(0.0, (target - datetime.now(timezone.utc)).total_seconds())


def get_retry_adapter(
attempts: int,
backoff: float,
initial_delay: float,
logger: Logger,
exceptions: list[Exception] = None,
reraise: bool = False,
min_backoff: float = 0.5,
max_backoff: float = 5,
before_sleep: Callable = None,
) -> Retrying:

if not exceptions:
exceptions = [Exception]

def _compute_wait(retry_state) -> float:
exc = retry_state.outcome.exception() if retry_state.outcome else None
if exc is not None:
response = getattr(exc, "response", None)
if response is not None:
retry_after = parse_retry_after(
response.headers.get("Retry-After")
)
if retry_after is not None:
return retry_after
return initial_delay * 2 ** (retry_state.attempt_number - 1)

return Retrying(
stop=stop_after_attempt(attempts),
wait=wait_exponential(multiplier=backoff, min=min_backoff, max=max_backoff),
retry=retry_if_exception_type(*exceptions),
wait=_compute_wait,
retry=retry_if_exception_type(tuple(exceptions)),
reraise=reraise,
before_sleep=before_sleep_log(logger, log_level=DEBUG),
before_sleep=before_sleep or before_sleep_log(logger, log_level=DEBUG),
)
10 changes: 10 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest

from marketdata.api_status import API_STATUS_DATA
from marketdata.client import MarketDataClient
from marketdata.types import UserRateLimits

Expand All @@ -25,6 +26,15 @@ def chdir(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.chdir(tmp_path)


@pytest.fixture(autouse=True)
def _reset_api_status_data():
API_STATUS_DATA.__init__()
yield
if API_STATUS_DATA._refresh_thread is not None:
API_STATUS_DATA._refresh_thread.join(timeout=2)
API_STATUS_DATA.__init__()


@pytest.fixture
def client(respx_mock):

Expand Down
Loading
Loading