Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/marketdata/input_types/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ class OptionsChainInput(BaseInputType):
description="The expiration date to filter by", default=None
)
days_to_expiration: int | None = Field(
description="The number of days to expiration to filter by", default=None
description="The number of days to expiration to filter by",
alias="dte",
default=None,
)
from_date: datetime.date | str | None = Field(
description="The start date to fetch options chain for",
Expand Down
68 changes: 68 additions & 0 deletions src/tests/test_input_types.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import datetime
import importlib
import pkgutil
from pathlib import Path

import pytest
from pydantic import Field, model_validator

import marketdata.input_types as input_types_pkg
from marketdata.exceptions import MinMaxDateValidationError
from marketdata.input_types.base import (
BaseInputType,
OutputFormat,
UserUniversalAPIParams,
)
from marketdata.internal_settings import GLOBAL_EXCLUDED_PARAMS


class DummyInput(BaseInputType):
Expand All @@ -22,6 +26,70 @@ def validate_input(self) -> "DummyInput":
return self


def _all_input_models() -> list[type[BaseInputType]]:
"""Return every concrete BaseInputType subclass defined in the SDK.

Imports each input_types submodule first so all subclasses are registered.
"""
for module_info in pkgutil.iter_modules(input_types_pkg.__path__):
importlib.import_module(f"{input_types_pkg.__name__}.{module_info.name}")

seen: set[type[BaseInputType]] = set()
stack: list[type[BaseInputType]] = [BaseInputType]
while stack:
for sub in stack.pop().__subclasses__():
if sub not in seen:
seen.add(sub)
stack.append(sub)
# Only audit models shipped in the SDK, not test-only helper subclasses.
return sorted(
(c for c in seen if c.__module__.startswith(input_types_pkg.__name__)),
key=lambda c: c.__name__,
)


def _snake_case_fields() -> list[tuple[str, str, "object"]]:
"""All (model_name, field_name, field) tuples for fields whose Python name
differs from a bare API parameter (i.e. contain an underscore)."""
cases = []
for model in _all_input_models():
for field_name, field in model.model_fields.items():
if "_" in field_name:
cases.append((model.__name__, field_name, field))
return cases


_SNAKE_CASE_FIELDS = _snake_case_fields()


def test_snake_case_field_discovery_is_not_empty():
# Guard: if discovery silently breaks, the parametrized test below would
# vacuously pass. days_to_expiration alone guarantees at least one case.
assert _SNAKE_CASE_FIELDS


@pytest.mark.parametrize(
"model_name, field_name, field",
_SNAKE_CASE_FIELDS,
ids=[f"{m}.{f}" for m, f, _ in _SNAKE_CASE_FIELDS],
)
def test_snake_case_input_fields_have_api_alias(model_name, field_name, field):
"""Every multi-word input field must be sent under an explicit API alias.

The URL builder serializes with ``by_alias=True``; a snake_case field
without an alias would leak its Python name to the wire (see issue #30,
``days_to_expiration`` -> ``dte``). Fields that are never serialized to the
query string are exempted via GLOBAL_EXCLUDED_PARAMS.
"""
if field_name in GLOBAL_EXCLUDED_PARAMS:
return

assert field.alias and field.alias != field_name, (
f"{model_name}.{field_name} has no API alias; it would be sent to the "
f"API as '{field_name}'. Add an alias or exclude it."
)


def test_base_input_type_min_max_validation():
with pytest.raises(MinMaxDateValidationError):
DummyInput(min_param="2025-01-01", max_param="2024-01-01")
Expand Down
19 changes: 19 additions & 0 deletions src/tests/test_options_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,22 @@ def test_options_chain_input_date_range_aliases_on_wire(load_json, respx_mock, c
assert params.get("to") == "2026-04-18"
assert params.get("from_date") is None
assert params.get("to_date") is None


def test_options_chain_input_days_to_expiration_alias_on_wire(
load_json, respx_mock, client
):
mock_data = load_json("options_chain_response_200")
respx_mock.get("https://api.marketdata.app/v1/options/chain/AAPL/").respond(
json=mock_data, status_code=200
)

client.options.chain(
"AAPL",
days_to_expiration=30,
output_format=OutputFormat.INTERNAL,
)

params = respx_mock.calls.last.request.url.params
assert params.get("dte") == "30"
assert params.get("days_to_expiration") is None
Loading