Skip to content

Commit 9c2c0b6

Browse files
committed
Route bars through Rust core and PyO3 bindings
1 parent 7dc3f78 commit 9c2c0b6

4 files changed

Lines changed: 204 additions & 61 deletions

File tree

crates/openquant/src/data_structures.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub enum StandardBarType {
2323
/// Bar output with OHLCV-like fields.
2424
#[derive(Debug, Clone, PartialEq)]
2525
pub struct StandardBar {
26+
pub start_timestamp: NaiveDateTime,
2627
pub timestamp: NaiveDateTime,
2728
pub open: f64,
2829
pub high: f64,
@@ -211,6 +212,7 @@ fn build_bar(trades: &[Trade]) -> StandardBar {
211212

212213
let open = trades.first().expect("non-empty slice").price;
213214
let close = trades.last().expect("non-empty slice").price;
215+
let start_timestamp = trades.first().expect("non-empty slice").timestamp;
214216
let timestamp = trades.last().expect("non-empty slice").timestamp;
215217
let (high, low) = trades.iter().fold((f64::NEG_INFINITY, f64::INFINITY), |(h, l), trade| {
216218
(h.max(trade.price), l.min(trade.price))
@@ -222,6 +224,7 @@ fn build_bar(trades: &[Trade]) -> StandardBar {
222224
});
223225

224226
StandardBar {
227+
start_timestamp,
225228
timestamp,
226229
open,
227230
high,

crates/pyopenquant/src/lib.rs

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use nalgebra::DMatrix;
2+
use openquant::data_structures::{standard_bars, time_bars, StandardBarType, Trade};
23
use openquant::filters::Threshold;
34
use openquant::pipeline::{
45
run_mid_frequency_pipeline, ResearchPipelineConfig, ResearchPipelineInput,
@@ -50,6 +51,61 @@ fn format_naive_datetimes(values: Vec<chrono::NaiveDateTime>) -> Vec<String> {
5051
values.into_iter().map(|v| v.format("%Y-%m-%d %H:%M:%S").to_string()).collect()
5152
}
5253

54+
fn parse_one_naive_datetime(value: &str) -> PyResult<chrono::NaiveDateTime> {
55+
chrono::NaiveDateTime::parse_from_str(value, "%Y-%m-%d %H:%M:%S")
56+
.or_else(|_| {
57+
chrono::NaiveDate::parse_from_str(value, "%Y-%m-%d")
58+
.map(|d| d.and_hms_opt(0, 0, 0).expect("valid fixed midnight"))
59+
})
60+
.map_err(|e| {
61+
PyValueError::new_err(format!(
62+
"invalid datetime '{value}' (expected '%Y-%m-%d %H:%M:%S' or '%Y-%m-%d'): {e}"
63+
))
64+
})
65+
}
66+
67+
fn build_trades(
68+
timestamps: Vec<String>,
69+
prices: Vec<f64>,
70+
volumes: Vec<f64>,
71+
) -> PyResult<Vec<Trade>> {
72+
if timestamps.len() != prices.len() || prices.len() != volumes.len() {
73+
return Err(PyValueError::new_err(format!(
74+
"timestamps/prices/volumes length mismatch: {} / {} / {}",
75+
timestamps.len(),
76+
prices.len(),
77+
volumes.len()
78+
)));
79+
}
80+
let mut trades = Vec::with_capacity(prices.len());
81+
for i in 0..prices.len() {
82+
trades.push(Trade {
83+
timestamp: parse_one_naive_datetime(&timestamps[i])?,
84+
price: prices[i],
85+
volume: volumes[i],
86+
});
87+
}
88+
Ok(trades)
89+
}
90+
91+
fn bars_to_rows(bars: Vec<openquant::data_structures::StandardBar>) -> Vec<(String, String, f64, f64, f64, f64, f64, f64, usize)> {
92+
bars.into_iter()
93+
.map(|b| {
94+
(
95+
b.start_timestamp.format("%Y-%m-%d %H:%M:%S").to_string(),
96+
b.timestamp.format("%Y-%m-%d %H:%M:%S").to_string(),
97+
b.open,
98+
b.high,
99+
b.low,
100+
b.close,
101+
b.volume,
102+
b.dollar_value,
103+
b.tick_count,
104+
)
105+
})
106+
.collect()
107+
}
108+
53109
#[pyfunction(name = "calculate_value_at_risk")]
54110
fn risk_calculate_value_at_risk(returns: Vec<f64>, confidence_level: f64) -> PyResult<f64> {
55111
RiskMetrics::default().calculate_value_at_risk(&returns, confidence_level).map_err(to_py_err)
@@ -154,6 +210,66 @@ fn sampling_seq_bootstrap(
154210
openquant::sampling::seq_bootstrap(&ind_mat, sample_length, warmup_samples)
155211
}
156212

213+
#[pyfunction(name = "build_time_bars")]
214+
fn bars_build_time_bars(
215+
timestamps: Vec<String>,
216+
prices: Vec<f64>,
217+
volumes: Vec<f64>,
218+
interval_seconds: i64,
219+
) -> PyResult<Vec<(String, String, f64, f64, f64, f64, f64, f64, usize)>> {
220+
if interval_seconds <= 0 {
221+
return Err(PyValueError::new_err("interval_seconds must be > 0"));
222+
}
223+
let trades = build_trades(timestamps, prices, volumes)?;
224+
let bars = time_bars(&trades, chrono::Duration::seconds(interval_seconds));
225+
Ok(bars_to_rows(bars))
226+
}
227+
228+
#[pyfunction(name = "build_tick_bars")]
229+
fn bars_build_tick_bars(
230+
timestamps: Vec<String>,
231+
prices: Vec<f64>,
232+
volumes: Vec<f64>,
233+
ticks_per_bar: usize,
234+
) -> PyResult<Vec<(String, String, f64, f64, f64, f64, f64, f64, usize)>> {
235+
if ticks_per_bar == 0 {
236+
return Err(PyValueError::new_err("ticks_per_bar must be > 0"));
237+
}
238+
let trades = build_trades(timestamps, prices, volumes)?;
239+
let bars = standard_bars(&trades, ticks_per_bar as f64, StandardBarType::Tick);
240+
Ok(bars_to_rows(bars))
241+
}
242+
243+
#[pyfunction(name = "build_volume_bars")]
244+
fn bars_build_volume_bars(
245+
timestamps: Vec<String>,
246+
prices: Vec<f64>,
247+
volumes: Vec<f64>,
248+
volume_per_bar: f64,
249+
) -> PyResult<Vec<(String, String, f64, f64, f64, f64, f64, f64, usize)>> {
250+
if !volume_per_bar.is_finite() || volume_per_bar <= 0.0 {
251+
return Err(PyValueError::new_err("volume_per_bar must be > 0"));
252+
}
253+
let trades = build_trades(timestamps, prices, volumes)?;
254+
let bars = standard_bars(&trades, volume_per_bar, StandardBarType::Volume);
255+
Ok(bars_to_rows(bars))
256+
}
257+
258+
#[pyfunction(name = "build_dollar_bars")]
259+
fn bars_build_dollar_bars(
260+
timestamps: Vec<String>,
261+
prices: Vec<f64>,
262+
volumes: Vec<f64>,
263+
dollar_value_per_bar: f64,
264+
) -> PyResult<Vec<(String, String, f64, f64, f64, f64, f64, f64, usize)>> {
265+
if !dollar_value_per_bar.is_finite() || dollar_value_per_bar <= 0.0 {
266+
return Err(PyValueError::new_err("dollar_value_per_bar must be > 0"));
267+
}
268+
let trades = build_trades(timestamps, prices, volumes)?;
269+
let bars = standard_bars(&trades, dollar_value_per_bar, StandardBarType::Dollar);
270+
Ok(bars_to_rows(bars))
271+
}
272+
157273
#[pyfunction(name = "get_signal")]
158274
fn bet_sizing_get_signal(prob: Vec<f64>, num_classes: usize, pred: Option<Vec<f64>>) -> Vec<f64> {
159275
openquant::bet_sizing::get_signal(&prob, num_classes, pred.as_deref())
@@ -325,6 +441,14 @@ fn _core(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
325441
m.add_submodule(&sampling)?;
326442
m.add("sampling", sampling)?;
327443

444+
let bars = PyModule::new_bound(py, "bars")?;
445+
bars.add_function(wrap_pyfunction!(bars_build_time_bars, &bars)?)?;
446+
bars.add_function(wrap_pyfunction!(bars_build_tick_bars, &bars)?)?;
447+
bars.add_function(wrap_pyfunction!(bars_build_volume_bars, &bars)?)?;
448+
bars.add_function(wrap_pyfunction!(bars_build_dollar_bars, &bars)?)?;
449+
m.add_submodule(&bars)?;
450+
m.add("bars", bars)?;
451+
328452
let bet_sizing = PyModule::new_bound(py, "bet_sizing")?;
329453
bet_sizing.add_function(wrap_pyfunction!(bet_sizing_get_signal, &bet_sizing)?)?;
330454
bet_sizing.add_function(wrap_pyfunction!(bet_sizing_discrete_signal, &bet_sizing)?)?;

docs/python_bindings.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Input conventions:
5252
- `timestamps`: list of strings formatted as `%Y-%m-%d %H:%M:%S`
5353
- timestamp variants require `len(close) == len(timestamps)`
5454

55-
### `openquant.bars` (AFML Ch.2 event-driven bars)
55+
### `openquant.bars` (AFML Ch.2 event-driven bars; Rust core via PyO3)
5656
- `build_time_bars(df, interval="1d")`
5757
- `build_tick_bars(df, ticks_per_bar=50)`
5858
- `build_volume_bars(df, volume_per_bar=100_000.0)`

python/openquant/bars.py

Lines changed: 76 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,62 @@
11
from __future__ import annotations
22

33
import math
4+
from typing import Callable
45

56
import polars as pl
67

8+
from . import _core
79
from . import data
810

911

10-
def _prepare(df: pl.DataFrame) -> pl.DataFrame:
11-
return data.clean_ohlcv(df)
12-
13-
14-
def _aggregate(df: pl.DataFrame) -> pl.DataFrame:
15-
out = (
16-
df.group_by(["symbol", "bar_id"])
17-
.agg(
18-
pl.col("ts").min().alias("start_ts"),
19-
pl.col("ts").max().alias("end_ts"),
20-
pl.col("open").first().alias("open"),
21-
pl.col("high").max().alias("high"),
22-
pl.col("low").min().alias("low"),
23-
pl.col("close").last().alias("close"),
24-
pl.col("adj_close").last().alias("adj_close"),
25-
pl.col("volume").sum().alias("volume"),
26-
pl.len().alias("n_obs"),
12+
def _interval_to_seconds(interval: str) -> int:
13+
s = interval.strip().lower()
14+
if s.endswith("d"):
15+
return int(s[:-1]) * 24 * 3600
16+
if s.endswith("h"):
17+
return int(s[:-1]) * 3600
18+
if s.endswith("m"):
19+
return int(s[:-1]) * 60
20+
if s.endswith("s"):
21+
return int(s[:-1])
22+
raise ValueError(f"unsupported interval format: {interval}")
23+
24+
25+
def _rows_to_frame(symbol: str, rows: list[tuple[str, str, float, float, float, float, float, float, int]]) -> pl.DataFrame:
26+
if not rows:
27+
return pl.DataFrame(
28+
{
29+
"ts": [],
30+
"symbol": [],
31+
"open": [],
32+
"high": [],
33+
"low": [],
34+
"close": [],
35+
"volume": [],
36+
"adj_close": [],
37+
"start_ts": [],
38+
"n_obs": [],
39+
"dollar_value": [],
40+
}
2741
)
28-
.sort(["symbol", "end_ts", "bar_id"])
29-
.drop("bar_id")
30-
.rename({"end_ts": "ts"})
31-
)
32-
return out.select(
42+
return pl.DataFrame(
43+
{
44+
"start_ts": [r[0] for r in rows],
45+
"ts": [r[1] for r in rows],
46+
"open": [r[2] for r in rows],
47+
"high": [r[3] for r in rows],
48+
"low": [r[4] for r in rows],
49+
"close": [r[5] for r in rows],
50+
"volume": [r[6] for r in rows],
51+
"dollar_value": [r[7] for r in rows],
52+
"n_obs": [r[8] for r in rows],
53+
}
54+
).with_columns(
55+
pl.lit(symbol).alias("symbol"),
56+
pl.col("start_ts").str.strptime(pl.Datetime, strict=False),
57+
pl.col("ts").str.strptime(pl.Datetime, strict=False),
58+
pl.col("close").alias("adj_close"),
59+
).select(
3360
[
3461
"ts",
3562
"symbol",
@@ -41,44 +68,46 @@ def _aggregate(df: pl.DataFrame) -> pl.DataFrame:
4168
"adj_close",
4269
"start_ts",
4370
"n_obs",
71+
"dollar_value",
4472
]
4573
)
4674

4775

76+
def _build_by_symbol(
    df: pl.DataFrame,
    rust_builder: Callable[[list[str], list[float], list[float], float | int], list[tuple[str, str, float, float, float, float, float, float, int]]],
    param: float | int,
) -> pl.DataFrame:
    """Clean the input, run ``rust_builder`` once per symbol, and stack the results.

    ``rust_builder`` is one of the ``_core.bars`` functions; it receives the
    symbol's timestamps (stringified), close prices, volumes, and ``param``
    (the bar-size threshold). Returns an empty-but-typed frame when the
    cleaned input has no rows.
    """
    clean = data.clean_ohlcv(df).sort(["symbol", "ts"])
    frames: list[pl.DataFrame] = []
    for symbol in clean["symbol"].unique().to_list():
        sdf = clean.filter(pl.col("symbol") == symbol).sort("ts")
        timestamps = [str(x) for x in sdf["ts"].to_list()]
        closes = [float(x) for x in sdf["close"].to_list()]
        sizes = [float(x) for x in sdf["volume"].to_list()]
        rows = rust_builder(timestamps, closes, sizes, param)
        frames.append(_rows_to_frame(symbol, rows))
    if not frames:
        return _rows_to_frame("", [])
    return pl.concat(frames, how="vertical").sort(["symbol", "ts"])
95+
96+
4897
def build_time_bars(df: pl.DataFrame, *, interval: str = "1d") -> pl.DataFrame:
    """Aggregate trades into fixed-interval time bars (AFML Ch. 2).

    ``interval`` uses the compact suffix syntax accepted by
    ``_interval_to_seconds`` (e.g. "1d", "4h", "15m", "30s").
    """
    seconds = _interval_to_seconds(interval)
    return _build_by_symbol(df, _core.bars.build_time_bars, seconds)
5299

53100

54101
def build_tick_bars(df: pl.DataFrame, *, ticks_per_bar: int = 50) -> pl.DataFrame:
    """Aggregate trades into tick bars: one bar per ``ticks_per_bar`` observations.

    Raises:
        ValueError: if ``ticks_per_bar`` is not positive.
    """
    if not ticks_per_bar > 0:
        raise ValueError("ticks_per_bar must be > 0")
    return _build_by_symbol(df, _core.bars.build_tick_bars, ticks_per_bar)
64105

65106

66107
def build_volume_bars(df: pl.DataFrame, *, volume_per_bar: float = 100_000.0) -> pl.DataFrame:
    """Aggregate trades into volume bars closing once ``volume_per_bar`` is reached.

    Raises:
        ValueError: if ``volume_per_bar`` is not positive (NaN/inf are rejected
            by the Rust core).
    """
    if not volume_per_bar > 0:
        raise ValueError("volume_per_bar must be > 0")
    return _build_by_symbol(df, _core.bars.build_volume_bars, volume_per_bar)
82111

83112

84113
def build_dollar_bars(
@@ -88,20 +117,7 @@ def build_dollar_bars(
88117
) -> pl.DataFrame:
89118
if dollar_value_per_bar <= 0:
90119
raise ValueError("dollar_value_per_bar must be > 0")
91-
clean = _prepare(df)
92-
eps = dollar_value_per_bar * 1e-9
93-
grouped = (
94-
clean.with_columns((pl.col("close") * pl.col("volume")).alias("dollar_value"))
95-
.with_columns(pl.col("dollar_value").cum_sum().over("symbol").alias("cum_dollar"))
96-
.with_columns(
97-
(((pl.col("cum_dollar") - eps).clip(lower_bound=0.0)) / dollar_value_per_bar)
98-
.floor()
99-
.cast(pl.Int64)
100-
.alias("bar_id")
101-
)
102-
.drop(["dollar_value", "cum_dollar"])
103-
)
104-
return _aggregate(grouped)
120+
return _build_by_symbol(df, _core.bars.build_dollar_bars, dollar_value_per_bar)
105121

106122

107123
def _lag1_autocorr(values: list[float]) -> float:
@@ -120,7 +136,7 @@ def _lag1_autocorr(values: list[float]) -> float:
120136

121137

122138
def bar_diagnostics(df: pl.DataFrame) -> dict[str, float]:
123-
clean = _prepare(df).sort(["symbol", "ts"])
139+
clean = data.clean_ohlcv(df).sort(["symbol", "ts"])
124140
returns = (
125141
clean.with_columns(
126142
(

0 commit comments

Comments
 (0)