Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 232 additions & 0 deletions src/strategy/mean_reversion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
use std::collections::HashMap;

use crate::{
model::ffi::{self},
strategy::{Strategy, StrategyParams},
types::{Bar, Signal},
};

pub struct MeanReversionStrategy {
name: String,
symbols: Vec<String>,
window: i32,
zscore_entry: f64,
zscore_exit: f64,
min_half_life: f64,
max_half_life: f64,
adf_lags: i32,
recompute_interval: usize,
max_buffer: usize,
startup_lookback: String,
startup_bar_limit: u32,
price_buffers: HashMap<String, Vec<f64>>,
bar_counts: HashMap<String, usize>,
signals: HashMap<String, SymbolState>,
}

struct SymbolState {
zscore: f64,
half_life: f64,
equilibrium: f64,
volatility: f64,
theta: f64,
is_stationary: bool,
adf_confidence: AdfConfidence,
}

#[derive(Clone, Copy)]
enum AdfConfidence {
None,
TenPercent,
FivePercent,
OnePercent,
}

impl AdfConfidence {
fn weight(self) -> f64 {
match self {
AdfConfidence::None => 0.0,
AdfConfidence::TenPercent => 0.5,
AdfConfidence::FivePercent => 0.75,
AdfConfidence::OnePercent => 1.0,
}
}
}

pub struct MeanReversionConfig {
pub name: String,
pub symbols: Vec<String>,
pub window: i32,
pub zscore_entry: f64,
pub zscore_exit: f64,
pub min_half_life: f64,
pub max_half_life: f64,
pub adf_lags: i32,
pub recompute_interval: usize,
pub max_buffer: usize,
pub startup_lookback: String,
pub startup_bar_limit: u32,
}

impl MeanReversionStrategy {
pub fn new(config: MeanReversionConfig) -> Self {
Self {
name: config.name,
symbols: config.symbols,
window: config.window,
zscore_entry: config.zscore_entry,
zscore_exit: config.zscore_exit,
min_half_life: config.min_half_life,
max_half_life: config.max_half_life,
adf_lags: config.adf_lags,
recompute_interval: config.recompute_interval,
max_buffer: config.max_buffer,
startup_lookback: config.startup_lookback,
startup_bar_limit: config.startup_bar_limit,
price_buffers: HashMap::new(),
bar_counts: HashMap::new(),
signals: HashMap::new(),
}
}

fn recompute(&mut self, symbol: &str) {
let prices = match self.price_buffers.get(symbol) {
Some(p) => p,
None => return,
};

let zscores = ffi::zscore(prices, self.window);
let (speeds, equilibria, volatility_sq) = ffi::ou_estimate(prices, self.window);
let adf_result = ffi::adf(prices, self.adf_lags);

let last = prices.len() - 1;
let speed = speeds[last];

let half_life = if speed > 0.0 && !speed.is_nan() {
(2.0_f64).ln() / speed
} else {
f64::NAN
};

let adf_confidence = if adf_result.reject_1pct {
AdfConfidence::OnePercent
} else if adf_result.reject_5pct {
AdfConfidence::FivePercent
} else if adf_result.reject_10pct {
AdfConfidence::TenPercent
} else {
AdfConfidence::None
};

self.signals.insert(
symbol.to_string(),
SymbolState {
zscore: zscores[last],
half_life,
equilibrium: equilibria[last],
volatility: volatility_sq[last].sqrt(),
theta: speed,
is_stationary: adf_result.reject_10pct,
adf_confidence,
},
);
}

fn generate_signal(&self, symbol: &str, price: f64) -> Option<Signal> {
let state = self.signals.get(symbol)?;

if !state.is_stationary {
return None;
}

if state.half_life.is_nan()
|| state.half_life < self.min_half_life
|| state.half_life > self.max_half_life
{
return None;
}

let z = state.zscore;
if z.is_nan() {
return None;
}

let baseline_vol = 0.02;
let vol_scalar = (baseline_vol / state.volatility).clamp(0.25, 1.5);
let raw_strength = (z.abs() / self.zscore_entry).min(1.5);
let strength = (raw_strength * state.adf_confidence.weight() * vol_scalar).min(1.0);

let expected_move = state.theta * (state.equilibrium - price);
let edge = expected_move / state.volatility;

if z < -self.zscore_entry {
Some(Signal::Buy {
symbol: symbol.to_string(),
strength,
})
} else if edge < self.zscore_exit {
Some(Signal::Sell {
symbol: symbol.to_string(),
strength,
})
} else {
None
}
}
}

impl Strategy for MeanReversionStrategy {
fn name(&self) -> &str {
&self.name
}

fn symbols(&self) -> &[String] {
&self.symbols
}

fn on_bar(&mut self, bar: &Bar) -> Option<Signal> {
let buffer = self.price_buffers.entry(bar.symbol.clone()).or_default();
buffer.push(bar.close);
if buffer.len() > self.max_buffer {
buffer.remove(0);
}

let count = self.bar_counts.entry(bar.symbol.clone()).or_insert(0);
*count += 1;

if (buffer.len() as i32) < self.window {
return None;
}

if (*count).is_multiple_of(self.recompute_interval) {
self.recompute(&bar.symbol);
}

self.generate_signal(&bar.symbol, bar.close)
}

fn reset(&mut self) {
self.price_buffers.clear();
self.bar_counts.clear();
self.signals.clear();
}

fn startup_config(&self) -> Option<(String, u32)> {
Some((self.startup_lookback.clone(), self.startup_bar_limit))
}

fn params(&self) -> super::StrategyParams {
StrategyParams::MeanReversion {
window: self.window,
zscore_entry: self.zscore_entry,
zscore_exit: self.zscore_exit,
min_half_life: self.min_half_life,
max_half_life: self.max_half_life,
adf_lags: self.adf_lags,
recompute_interval: self.recompute_interval,
max_buffer: self.max_buffer,
startup_lookback: self.startup_lookback.clone(),
startup_bar_limit: self.startup_bar_limit,
}
}
}
44 changes: 37 additions & 7 deletions src/strategy/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod mean_reversion;
pub mod momentum;

use crate::types::{Bar, Signal};
Expand All @@ -21,8 +22,16 @@ pub enum StrategyParams {
startup_bar_limit: u32,
},
MeanReversion {
window: usize,
std_devs: f64,
window: i32,
zscore_entry: f64,
zscore_exit: f64,
min_half_life: f64,
max_half_life: f64,
adf_lags: i32,
recompute_interval: usize,
max_buffer: usize,
startup_lookback: String,
startup_bar_limit: u32,
},
Custom {
params: std::collections::HashMap<String, f64>,
Expand Down Expand Up @@ -57,11 +66,32 @@ pub fn create_strategy(config: &StrategySettings) -> Box<dyn Strategy> {
*startup_bar_limit,
)),
StrategyParams::MeanReversion {
window: _,
std_devs: _,
} => {
todo!("Need to implement")
}
window,
zscore_entry,
zscore_exit,
min_half_life,
max_half_life,
adf_lags,
recompute_interval,
max_buffer,
startup_lookback,
startup_bar_limit,
} => Box::new(mean_reversion::MeanReversionStrategy::new(
mean_reversion::MeanReversionConfig {
name: config.name.clone(),
symbols: config.symbols.clone(),
window: *window,
zscore_entry: *zscore_entry,
zscore_exit: *zscore_exit,
min_half_life: *min_half_life,
max_half_life: *max_half_life,
adf_lags: *adf_lags,
recompute_interval: *recompute_interval,
max_buffer: *max_buffer,
startup_lookback: startup_lookback.clone(),
startup_bar_limit: *startup_bar_limit,
},
)),
StrategyParams::Custom { .. } => {
todo!("Need to implement")
}
Expand Down
2 changes: 1 addition & 1 deletion streamlit/pages/overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
st.subheader("Equity Curve")
history = get_account_history(account_id, 500)
if not history.empty:
history['timestamp'] = pd.to_datetime(history['timestamp'], format='ISO8601')
history['timestamp'] = pd.to_datetime(history['timestamp'], format='ISO8601', utc=True)
history = history.sort_values('timestamp')

fig = go.Figure()
Expand Down
57 changes: 47 additions & 10 deletions streamlit/strategies/mean_reversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,28 @@ def strategy_type(self) -> str:

def default_params(self) -> dict:
return {
"window": 20,
"std_devs": 2.0,
"window": 60,
"zscore_entry": 2.0,
"zscore_exit": 0.5,
"min_half_life": 2.0,
"max_half_life": 50.0,
"adf_lags": 1,
"recompute_interval": 5,
"max_buffer": 500,
"startup_lookback": "1Hour",
"startup_bar_limit": 50,
"startup_bar_limit": 500,
}

def param_descriptions(self) -> dict:
return {
"window": "Rolling window size for computing the moving average.",
"std_devs": "Number of standard deviations from the mean to trigger a signal.",
"window": "Rolling window size for z-score and OU parameter estimation.",
"zscore_entry": "Z-score threshold to enter a trade. Higher = fewer but stronger signals.",
"zscore_exit": "Z-score threshold to exit. When price reverts past equilibrium, close the position.",
"min_half_life": "Minimum mean-reversion half-life in bars. Below this, signal is noise.",
"max_half_life": "Maximum mean-reversion half-life in bars. Above this, capital is tied up too long.",
"adf_lags": "Number of lags for the Augmented Dickey-Fuller stationarity test.",
"recompute_interval": "Recompute GPU signals every N bars. Lower = more responsive, higher = less GPU usage.",
"max_buffer": "Maximum price history to keep per symbol. Also controls startup bar fetch.",
"startup_lookback": "Bar timeframe used during historical warmup before live trading begins.",
"startup_bar_limit": "Number of historical bars to fetch during warmup.",
}
Expand All @@ -29,24 +41,49 @@ def render_params(self, st) -> dict:

col1, col2 = st.columns(2)
with col1:
window = st.number_input("Window Size", min_value=5, value=20)
window = st.number_input("Window Size", min_value=5, value=60)
st.caption(descs["window"])

zscore_entry = st.number_input("Z-Score Entry", min_value=0.5, value=2.0, step=0.1, format="%.1f")
st.caption(descs["zscore_entry"])

min_half_life = st.number_input("Min Half-Life", min_value=1.0, value=2.0, step=1.0, format="%.1f")
st.caption(descs["min_half_life"])

adf_lags = st.number_input("ADF Lags", min_value=1, value=1)
st.caption(descs["adf_lags"])

startup_lookback = st.selectbox(
"Startup Timeframe", ["1Min", "5Min", "15Min", "30Min", "1Hour", "1Day"], index=4
)
st.caption(descs["startup_lookback"])

with col2:
std_devs = st.number_input("Std Deviations", min_value=0.5, value=2.0, step=0.1, format="%.1f")
st.caption(descs["std_devs"])
zscore_exit = st.number_input("Z-Score Exit", min_value=0.0, value=0.5, step=0.1, format="%.1f")
st.caption(descs["zscore_exit"])

max_half_life = st.number_input("Max Half-Life", min_value=5.0, value=50.0, step=5.0, format="%.1f")
st.caption(descs["max_half_life"])

recompute_interval = st.number_input("Recompute Interval", min_value=1, value=5)
st.caption(descs["recompute_interval"])

max_buffer = st.number_input("Max Buffer", min_value=50, value=500, step=50)
st.caption(descs["max_buffer"])

startup_bar_limit = st.number_input("Startup Bar Limit", min_value=10, value=50)
startup_bar_limit = st.number_input("Startup Bar Limit", min_value=10, value=500)
st.caption(descs["startup_bar_limit"])

return {
"window": window,
"std_devs": std_devs,
"zscore_entry": zscore_entry,
"zscore_exit": zscore_exit,
"min_half_life": min_half_life,
"max_half_life": max_half_life,
"adf_lags": adf_lags,
"recompute_interval": recompute_interval,
"max_buffer": max_buffer,
"startup_lookback": startup_lookback,
"startup_bar_limit": startup_bar_limit,
}

Loading