Skip to content

Commit b00c4ca

Browse files
authored
feat(models): add backtesting helpers (#2)
1 parent 61a5f2b commit b00c4ca

6 files changed

Lines changed: 621 additions & 12 deletions

File tree

.github/workflows/tests.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ on:
88
jobs:
99
test:
1010
runs-on: ubuntu-latest
11+
strategy:
12+
matrix:
13+
extras: ["dev", "dev,backtest"]
1114

1215
steps:
1316
- name: Checkout repository
@@ -19,8 +22,8 @@ jobs:
1922
activate-environment: true
2023
python-version: "3.13"
2124

22-
- name: Install dependencies
23-
run: uv pip install -e ".[dev]"
25+
- name: Install dependencies with extras
26+
run: uv pip install -e ".[${{ matrix.extras }}]"
2427

2528
- name: Run tests with pytest
2629
run: uv run pytest

README.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ You can install it directly from the GitHub repository.
1212
pip install git+https://github.com/ZeroGuacamole/mode-python-sdk.git
1313
```
1414

15+
Install with backtesting helpers (pandas/numpy) via extras:
16+
17+
```bash
18+
pip install "mode-sdk[backtest] @ git+https://github.com/ZeroGuacamole/mode-python-sdk.git"
19+
```
20+
1521
## Quickstart
1622

1723
Here's a quick example of how to use the client to fetch historical data.
@@ -60,6 +66,45 @@ except ModeAPIError as e:
6066

6167
```
6268

69+
### Helpers
70+
71+
The models include utilities commonly used in research/backtesting pipelines.
72+
73+
1. Convert historical data to a pandas DataFrame (UTC index):
74+
75+
```python
76+
from mode_sdk.client import ModeAPIClient
77+
78+
client = ModeAPIClient()
79+
hist = client.market_data.get_historical_data("AAPL", "2024-01-01", "2024-01-31", "daily")
80+
81+
# Requires: pip install pandas
82+
df = hist.to_dataframe()
83+
print(df.head())
84+
```
85+
86+
2. Convert historical data to NumPy arrays for vectorized processing:
87+
88+
```python
89+
# Requires: pip install numpy
90+
ts, open_, high, low, close, volume = hist.to_numpy()
91+
```
92+
93+
3. Quote convenience properties:
94+
95+
```python
96+
quotes = client.market_data.get_quotes(["AAPL"]).quotes
97+
q = quotes["AAPL"]
98+
print(q.mid_price, q.spread)
99+
```
100+
101+
### Data validation and normalization
102+
103+
- Symbols are normalized to uppercase in `Asset` and `HistoricalDataResponse`.
104+
- Timestamps are normalized to UTC in all models that include time fields.
105+
- OHLCV values are validated (non-negative; high/low consistency) for `HistoricalDataPoint`.
106+
- `Quote` validation ensures non-negative prices and `ask >= bid` when both are present.
107+
63108
## Development
64109

65110
1. Clone the repository.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ classifiers = [
1414
"Operating System :: OS Independent",
1515
]
1616
dependencies = [
17-
"requests==2.32.4",
17+
"requests==2.32.5",
1818
"pydantic==2.11.7",
1919
"python-dotenv",
2020
"types-requests>=2.32.4",
2121
]
2222

2323
[project.optional-dependencies]
2424
dev = ["pytest", "pytest-mock", "requests-mock", "black", "ruff", "mypy"]
25+
backtest = ["pandas>=2.3.2", "numpy>=2.3.2"]
2526

2627
[tool.ruff]
2728
line-length = 120

src/mode_sdk/models.py

Lines changed: 195 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from datetime import datetime
1+
from datetime import datetime, timezone
2+
import importlib
23
from enum import Enum
34
from typing import Any, Dict, List, Optional
45

5-
from pydantic import BaseModel, Field, field_validator, ValidationInfo
6+
from pydantic import BaseModel, Field, field_validator, ValidationInfo, model_validator
67

78

89
class AssetType(str, Enum):
@@ -53,6 +54,20 @@ def validate_details(cls, v: Any, info: ValidationInfo) -> Any:
5354

5455
return v
5556

57+
@field_validator("symbol", mode="after")
@classmethod
def normalize_symbol(cls, v: str) -> str:
    """Uppercase the symbol so every model shares one canonical spelling."""
    normalized = v.upper()
    return normalized
62+
63+
@field_validator("last_updated", mode="after")
@classmethod
def normalize_last_updated(cls, v: datetime) -> datetime:
    """Return the timestamp as an aware UTC datetime.

    Naive datetimes are assumed to already be in UTC; aware ones are
    converted to UTC.
    """
    if v.tzinfo is not None:
        return v.astimezone(timezone.utc)
    return v.replace(tzinfo=timezone.utc)
70+
5671

5772
class Quote(BaseModel):
5873
"""Represents a real-time quote for a symbol."""
@@ -70,13 +85,59 @@ class Quote(BaseModel):
7085
previous_close: Optional[float] = Field(default=None, alias="previousClose")
7186
open: Optional[float] = None
7287

88+
@field_validator("timestamp", mode="after")
@classmethod
def normalize_timestamp(cls, v: datetime) -> datetime:
    """Coerce the quote timestamp to an aware UTC datetime (naive => UTC)."""
    aware = v if v.tzinfo is not None else v.replace(tzinfo=timezone.utc)
    return aware.astimezone(timezone.utc)
95+
96+
@model_validator(mode="after")
def validate_prices(self) -> "Quote":
    """Basic sanity checks for price fields used in backtesting."""
    # Every price-like field must be non-negative when present.
    for field_name in ("price", "bid", "ask"):
        value = getattr(self, field_name)
        if value is not None and value < 0:
            raise ValueError(f"{field_name} must be non-negative")
    # A crossed book (ask below bid) is rejected outright.
    if self.bid is not None and self.ask is not None and self.ask < self.bid:
        raise ValueError("ask must be greater than or equal to bid")
    return self
108+
109+
@property
def mid_price(self) -> Optional[float]:
    """Return the mid price ((bid + ask) / 2) when both sides are quoted.

    Falls back to the last traded price otherwise. Note that ``price`` is
    itself optional (the sibling ``validate_prices`` checks it against
    ``None``), so this property can return ``None`` — the previous
    ``-> float`` annotation was incorrect.
    """
    if self.bid is not None and self.ask is not None:
        return (self.bid + self.ask) / 2.0
    return self.price
115+
116+
@property
def spread(self) -> Optional[float]:
    """Bid/ask spread, or None when either side of the book is missing."""
    if self.bid is None or self.ask is None:
        return None
    return self.ask - self.bid
122+
73123

74124
class QuoteResponse(BaseModel):
    """Payload of the quotes API: per-symbol quotes plus per-symbol errors."""

    quotes: Dict[str, Quote]
    errors: Dict[str, str]

    @model_validator(mode="after")
    def validate_quote_keys(self) -> "QuoteResponse":
        """Reject responses whose mapping keys disagree with the nested symbols."""
        for key, quote in self.quotes.items():
            if not quote.symbol:
                # API may return partial quotes without a symbol; skip those.
                continue
            if key.upper() != quote.symbol.upper():
                raise ValueError(
                    f"quotes key '{key}' does not match nested symbol '{quote.symbol}'"
                )
        return self
140+
80141

81142
class HistoricalDataPoint(BaseModel):
82143
"""Represents a single OHLCV data point."""
@@ -88,9 +149,141 @@ class HistoricalDataPoint(BaseModel):
88149
close: Optional[float] = None
89150
volume: Optional[int] = None
90151

152+
@field_validator("timestamp", mode="after")
@classmethod
def normalize_timestamp(cls, v: datetime) -> datetime:
    """Make bar timestamps timezone-aware, expressed in UTC (naive => UTC)."""
    if v.tzinfo is not None:
        return v.astimezone(timezone.utc)
    return v.replace(tzinfo=timezone.utc)
159+
160+
@model_validator(mode="after")
def validate_ohlcv(self) -> "HistoricalDataPoint":
    """Sanity checks for OHLCV used in backtesting pipelines."""
    # All price fields and volume must be non-negative when present.
    for name in ("open", "high", "low", "close"):
        value = getattr(self, name)
        if value is not None and value < 0:
            raise ValueError(f"{name} must be non-negative")
    if self.volume is not None and self.volume < 0:
        raise ValueError("volume must be non-negative")

    # High must dominate every other provided price field.
    others_high = [p for p in (self.open, self.close, self.low) if p is not None]
    if self.high is not None and others_high and self.high < max(others_high):
        raise ValueError("high must be >= max(open, close, low) when provided")

    # Low must be dominated by every other provided price field.
    others_low = [p for p in (self.open, self.close, self.high) if p is not None]
    if self.low is not None and others_low and self.low > min(others_low):
        raise ValueError("low must be <= min(open, close, high) when provided")

    return self
192+
91193

92194
class HistoricalDataResponse(BaseModel):
    """Represents the structure of the historical data API response."""

    # Symbol the bars belong to; normalized to uppercase below.
    symbol: str
    data_points: List[HistoricalDataPoint] = Field(..., alias="dataPoints")

    @field_validator("symbol", mode="after")
    @classmethod
    def normalize_symbol(cls, v: str) -> str:
        """Normalize the response symbol to uppercase, matching `Asset`."""
        return v.upper()

    def to_records(self) -> List[Dict[str, Any]]:
        """Return the historical data as a list of dictionaries.

        Each record carries the keys: timestamp, open, high, low, close, volume.
        """
        return [
            {
                "timestamp": point.timestamp,
                "open": point.open,
                "high": point.high,
                "low": point.low,
                "close": point.close,
                "volume": point.volume,
            }
            for point in self.data_points
        ]

    def to_dataframe(self):
        """Convert the historical data to a pandas DataFrame (if pandas is installed).

        Returns a DataFrame indexed by UTC timestamps with columns: open, high, low,
        close, volume. The frame is sorted by index and duplicate timestamps are
        collapsed keeping the last occurrence.

        Raises:
            ImportError: If pandas is not installed (install the 'backtest' extra).
        """
        # Import lazily so the core SDK works without the backtest extra.
        try:
            pd = importlib.import_module("pandas")
        except ImportError as exc:  # narrow: only a missing pandas install
            raise ImportError(
                "pandas is required for to_dataframe(); install with 'pip install pandas'"
            ) from exc

        records = self.to_records()
        if not records:
            # Empty frame with the documented columns and float dtypes so
            # callers can rely on a stable schema.
            return pd.DataFrame(
                columns=["open", "high", "low", "close", "volume"]
            ).astype(
                {
                    "open": "float64",
                    "high": "float64",
                    "low": "float64",
                    "close": "float64",
                    "volume": "float64",
                }
            )

        frame = pd.DataFrame.from_records(records)
        frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True)
        frame.set_index("timestamp", inplace=True)
        frame.sort_index(inplace=True)
        # Keep the most recent observation for any duplicated timestamp.
        frame = frame[~frame.index.duplicated(keep="last")]
        return frame

    def to_numpy(self):
        """Return numpy arrays (timestamps, open, high, low, close, volume).

        Arrays are suitable for fast vectorized backtests. Missing OHLCV values
        are encoded as NaN; timestamps are tz-naive UTC datetime64[ns].

        Raises:
            ImportError: If numpy is not installed (install the 'backtest' extra).
        """
        # Import lazily so the core SDK works without the backtest extra.
        try:
            np = importlib.import_module("numpy")
        except ImportError as exc:  # narrow: only a missing numpy install
            raise ImportError(
                "numpy is required for to_numpy(); install with 'pip install numpy'"
            ) from exc

        points = self.data_points
        n = len(points)
        ts = np.empty(n, dtype="datetime64[ns]")
        open_ = np.full(n, np.nan)
        high = np.full(n, np.nan)
        low = np.full(n, np.nan)
        close = np.full(n, np.nan)
        vol = np.full(n, np.nan)

        for i, p in enumerate(points):
            # Strip tzinfo after converting to UTC: datetime64 is tz-naive.
            dt_utc = p.timestamp.astimezone(timezone.utc).replace(tzinfo=None)
            ts[i] = np.datetime64(dt_utc, "ns")
            if p.open is not None:
                open_[i] = p.open
            if p.high is not None:
                high[i] = p.high
            if p.low is not None:
                low[i] = p.low
            if p.close is not None:
                close[i] = p.close
            if p.volume is not None:
                vol[i] = p.volume

        return ts, open_, high, low, close, vol

0 commit comments

Comments
 (0)