Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/marketdata/output_handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import types
from abc import ABC, abstractmethod
from dataclasses import is_dataclass
from datetime import date, datetime
Expand Down Expand Up @@ -29,7 +30,9 @@ def _type_includes(self, field_type: Any, target: type) -> bool:
args = get_args(field_type)
if origin in (list, list, Iterable):
return any(self._type_includes(arg, target) for arg in args)
if origin is Union:
# Handle both typing.Union[X, None] and the PEP 604 `X | None` form,
# whose origin is types.UnionType rather than typing.Union.
if origin is Union or origin is types.UnionType:
return any(
self._type_includes(arg, target)
for arg in args
Expand Down
5 changes: 3 additions & 2 deletions src/marketdata/output_types/options_expirations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
class OptionsExpirations:
s: str
expirations: list[datetime.datetime]
updated: datetime.datetime
updated: datetime.datetime | None = None

def __post_init__(self):
self.updated = format_timestamp(self.updated)
if self.updated is not None:
self.updated = format_timestamp(self.updated)
self.expirations = [
format_timestamp(expiration) for expiration in self.expirations
]
Expand Down
7 changes: 6 additions & 1 deletion src/marketdata/resources/options/expirations.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,13 @@ def expirations(
if user_universal_params.output_format == OutputFormat.DATAFRAME:
data = response.json()
handler = get_dataframe_output_handler()
# When the user explicitly filters columns we must not force
# "expirations" into the index: doing so when it is the only requested
# column would promote all data into the index and leave an apparently
# empty DataFrame.
index_columns = [] if user_universal_params.columns else ["expirations"]
return handler(data, output_model, user_universal_params).get_result(
index_columns=["expirations"]
index_columns=index_columns
)

elif user_universal_params.output_format == OutputFormat.INTERNAL:
Expand Down
125 changes: 125 additions & 0 deletions src/tests/test_options_expirations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
from unittest.mock import patch

import pandas as pd
import pytz

from marketdata.input_types.base import (
Expand Down Expand Up @@ -191,6 +192,130 @@ def test_get_options_expirations_status_offline(load_json, respx_mock, client):
assert isinstance(expirations, MarketDataClientErrorResult)


def test_options_expirations_optional_updated():
"""Issue #23: the `updated` field must be optional so partial API
responses (e.g. when filtering columns) don't raise.
"""
instance = OptionsExpirations(
s="ok",
expirations=[1764910800],
updated=None,
)
assert instance.updated is None
assert isinstance(str(instance), str)


def test_get_options_expirations_columns_filter_dataframe_pandas(
respx_mock, client
):
"""Issue #23: requesting `columns=["expirations"]` makes the API return
only that column. The result must NOT be an empty DataFrame with the data
silently moved into the index.
"""
with patch(
"marketdata.output_handlers.DATAFRAME_HANDLERS_PRIORITY",
["pandas"],
):
expiration_timestamps = [1764910800, 1765515600, 1766120400]
# Server-side column filtering: only the requested column comes back.
partial_data = {
"s": "ok",
"expirations": expiration_timestamps,
}
respx_mock.get(
"https://api.marketdata.app/v1/options/expirations/AAPL/"
).respond(
json=partial_data,
status_code=200,
)

df = client.options.expirations(
symbol="AAPL",
output_format=OutputFormat.DATAFRAME,
columns=["expirations"],
)

# The data must stay as an "expirations" column on a default
# RangeIndex, not be silently promoted into the index.
expected_df = pd.DataFrame(
{
"expirations": pd.to_datetime(
expiration_timestamps, unit="s", utc=True
).tz_convert(ET)
}
)
pd.testing.assert_frame_equal(df, expected_df)


def test_get_options_expirations_columns_filter_dataframe_polars(
respx_mock, client
):
"""Issue #23 (regression guard for polars): filtering by a single column
must keep the data accessible as a column.
"""
with patch(
"marketdata.output_handlers.DATAFRAME_HANDLERS_PRIORITY",
["polars"],
):
expiration_timestamps = [1764910800, 1765515600, 1766120400]
partial_data = {
"s": "ok",
"expirations": expiration_timestamps,
}
respx_mock.get(
"https://api.marketdata.app/v1/options/expirations/AAPL/"
).respond(
json=partial_data,
status_code=200,
)

df = client.options.expirations(
symbol="AAPL",
output_format=OutputFormat.DATAFRAME,
columns=["expirations"],
)

# A single "expirations" column holding the timestamps converted to
# US/Eastern datetimes, with nothing dropped.
expected_expirations = [
datetime.datetime.fromtimestamp(ts, tz=ET)
for ts in expiration_timestamps
]
assert df.columns == ["expirations"]
assert df["expirations"].to_list() == expected_expirations


def test_get_options_expirations_partial_response_internal(respx_mock, client):
"""Issue #23: an INTERNAL response missing the `updated` field must parse
successfully instead of failing and returning an error result.
"""
expiration_timestamps = [1764910800, 1765515600, 1766120400]
partial_data = {
"s": "ok",
"expirations": expiration_timestamps,
}
respx_mock.get(
"https://api.marketdata.app/v1/options/expirations/AAPL/"
).respond(
json=partial_data,
status_code=200,
)

expirations = client.options.expirations(
symbol="AAPL", output_format=OutputFormat.INTERNAL
)

# The partial response parses, with timestamps converted to US/Eastern
# datetimes and the absent `updated` field left as None.
expected_expirations = [
datetime.datetime.fromtimestamp(ts, tz=ET) for ts in expiration_timestamps
]
assert isinstance(expirations, OptionsExpirations)
assert expirations.s == "ok"
assert expirations.expirations == expected_expirations
assert expirations.updated is None


def test_get_options_expirations_response_200_csv(respx_mock, client):
respx_mock.get("https://api.marketdata.app/v1/options/expirations/AAPL/").respond(
text="AS RECEIVED FROM API",
Expand Down
20 changes: 20 additions & 0 deletions src/tests/test_output_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ class DummySchemaOptionalDates:
updated: Union[datetime.datetime, None] = None


@dataclass
class DummySchemaNonDateContainer:
mapping: dict[str, int]
updated: datetime.datetime


class PassthroughHandler(BaseOutputHandler):
def _get_result(self, *args, **kwargs):
return {"ok": True}
Expand Down Expand Up @@ -115,6 +121,20 @@ def test_base_output_handler_date_columns_from_schema():
assert handler._get_datetime_columns() == ["updated"]


def test_base_output_handler_ignores_non_date_container_fields():
"""A field whose type origin is neither a sequence nor a union (e.g. a
`dict`) must not be treated as a date column — exercises the fall-through
return in `_type_includes`.
"""
handler = PassthroughHandler(
data={},
output_schema=DummySchemaNonDateContainer,
user_universal_params=_make_params(),
)
assert handler._get_date_columns() == []
assert handler._get_datetime_columns() == ["updated"]


def test_base_output_handler_non_dataclass_schema():
handler = PassthroughHandler(
data={},
Expand Down
Loading