diff --git a/src/strategy/mean_reversion.rs b/src/strategy/mean_reversion.rs new file mode 100644 index 0000000..d60f3be --- /dev/null +++ b/src/strategy/mean_reversion.rs @@ -0,0 +1,232 @@ +use std::collections::HashMap; + +use crate::{ + model::ffi::{self}, + strategy::{Strategy, StrategyParams}, + types::{Bar, Signal}, +}; + +pub struct MeanReversionStrategy { + name: String, + symbols: Vec, + window: i32, + zscore_entry: f64, + zscore_exit: f64, + min_half_life: f64, + max_half_life: f64, + adf_lags: i32, + recompute_interval: usize, + max_buffer: usize, + startup_lookback: String, + startup_bar_limit: u32, + price_buffers: HashMap>, + bar_counts: HashMap, + signals: HashMap, +} + +struct SymbolState { + zscore: f64, + half_life: f64, + equilibrium: f64, + volatility: f64, + theta: f64, + is_stationary: bool, + adf_confidence: AdfConfidence, +} + +#[derive(Clone, Copy)] +enum AdfConfidence { + None, + TenPercent, + FivePercent, + OnePercent, +} + +impl AdfConfidence { + fn weight(self) -> f64 { + match self { + AdfConfidence::None => 0.0, + AdfConfidence::TenPercent => 0.5, + AdfConfidence::FivePercent => 0.75, + AdfConfidence::OnePercent => 1.0, + } + } +} + +pub struct MeanReversionConfig { + pub name: String, + pub symbols: Vec, + pub window: i32, + pub zscore_entry: f64, + pub zscore_exit: f64, + pub min_half_life: f64, + pub max_half_life: f64, + pub adf_lags: i32, + pub recompute_interval: usize, + pub max_buffer: usize, + pub startup_lookback: String, + pub startup_bar_limit: u32, +} + +impl MeanReversionStrategy { + pub fn new(config: MeanReversionConfig) -> Self { + Self { + name: config.name, + symbols: config.symbols, + window: config.window, + zscore_entry: config.zscore_entry, + zscore_exit: config.zscore_exit, + min_half_life: config.min_half_life, + max_half_life: config.max_half_life, + adf_lags: config.adf_lags, + recompute_interval: config.recompute_interval, + max_buffer: config.max_buffer, + startup_lookback: config.startup_lookback, + startup_bar_limit: config.startup_bar_limit, + price_buffers: HashMap::new(), + bar_counts: HashMap::new(), + signals: HashMap::new(), + } + } + + fn recompute(&mut self, symbol: &str) { + let prices = match self.price_buffers.get(symbol) { + Some(p) => p, + None => return, + }; + + let zscores = ffi::zscore(prices, self.window); + let (speeds, equilibria, volatility_sq) = ffi::ou_estimate(prices, self.window); + let adf_result = ffi::adf(prices, self.adf_lags); + + let last = prices.len() - 1; + let speed = speeds[last]; + + let half_life = if speed > 0.0 && !speed.is_nan() { + (2.0_f64).ln() / speed + } else { + f64::NAN + }; + + let adf_confidence = if adf_result.reject_1pct { + AdfConfidence::OnePercent + } else if adf_result.reject_5pct { + AdfConfidence::FivePercent + } else if adf_result.reject_10pct { + AdfConfidence::TenPercent + } else { + AdfConfidence::None + }; + + self.signals.insert( + symbol.to_string(), + SymbolState { + zscore: zscores[last], + half_life, + equilibrium: equilibria[last], + volatility: volatility_sq[last].sqrt(), + theta: speed, + is_stationary: adf_result.reject_10pct, + adf_confidence, + }, + ); + } + + fn generate_signal(&self, symbol: &str, price: f64) -> Option { + let state = self.signals.get(symbol)?; + + if !state.is_stationary { + return None; + } + + if state.half_life.is_nan() + || state.half_life < self.min_half_life + || state.half_life > self.max_half_life + { + return None; + } + + let z = state.zscore; + if z.is_nan() { + return None; + } + + let baseline_vol = 0.02; + let vol_scalar = (baseline_vol / state.volatility).clamp(0.25, 1.5); + let raw_strength = (z.abs() / self.zscore_entry).min(1.5); + let strength = (raw_strength * state.adf_confidence.weight() * vol_scalar).min(1.0); + + let expected_move = state.theta * (state.equilibrium - price); + let edge = expected_move / state.volatility; + + if z < -self.zscore_entry { + Some(Signal::Buy { + symbol: symbol.to_string(), + strength, + }) + } else if edge < self.zscore_exit { + Some(Signal::Sell { + symbol: symbol.to_string(), + strength, + }) + } else { + None + } + } +} + +impl Strategy for MeanReversionStrategy { + fn name(&self) -> &str { + &self.name + } + + fn symbols(&self) -> &[String] { + &self.symbols + } + + fn on_bar(&mut self, bar: &Bar) -> Option { + let buffer = self.price_buffers.entry(bar.symbol.clone()).or_default(); + buffer.push(bar.close); + if buffer.len() > self.max_buffer { + buffer.remove(0); + } + + let count = self.bar_counts.entry(bar.symbol.clone()).or_insert(0); + *count += 1; + + if (buffer.len() as i32) < self.window { + return None; + } + + if (*count).is_multiple_of(self.recompute_interval) { + self.recompute(&bar.symbol); + } + + self.generate_signal(&bar.symbol, bar.close) + } + + fn reset(&mut self) { + self.price_buffers.clear(); + self.bar_counts.clear(); + self.signals.clear(); + } + + fn startup_config(&self) -> Option<(String, u32)> { + Some((self.startup_lookback.clone(), self.startup_bar_limit)) + } + + fn params(&self) -> super::StrategyParams { + StrategyParams::MeanReversion { + window: self.window, + zscore_entry: self.zscore_entry, + zscore_exit: self.zscore_exit, + min_half_life: self.min_half_life, + max_half_life: self.max_half_life, + adf_lags: self.adf_lags, + recompute_interval: self.recompute_interval, + max_buffer: self.max_buffer, + startup_lookback: self.startup_lookback.clone(), + startup_bar_limit: self.startup_bar_limit, + } + } +} diff --git a/src/strategy/mod.rs b/src/strategy/mod.rs index 743cbb0..b815e9b 100644 --- a/src/strategy/mod.rs +++ b/src/strategy/mod.rs @@ -1,3 +1,4 @@ +pub mod mean_reversion; pub mod momentum; use crate::types::{Bar, Signal}; @@ -21,8 +22,16 @@ pub enum StrategyParams { startup_bar_limit: u32, }, MeanReversion { - window: usize, - std_devs: f64, + window: i32, + zscore_entry: f64, + zscore_exit: f64, + min_half_life: f64, + max_half_life: f64, + adf_lags: i32, + recompute_interval: usize, + max_buffer: usize, + startup_lookback: String, + startup_bar_limit: u32, }, Custom { params: std::collections::HashMap, @@ -57,11 +66,32 @@ pub fn create_strategy(config: &StrategySettings) -> Box { *startup_bar_limit, )), StrategyParams::MeanReversion { - window: _, - std_devs: _, - } => { - todo!("Need to implement") - } + window, + zscore_entry, + zscore_exit, + min_half_life, + max_half_life, + adf_lags, + recompute_interval, + max_buffer, + startup_lookback, + startup_bar_limit, + } => Box::new(mean_reversion::MeanReversionStrategy::new( + mean_reversion::MeanReversionConfig { + name: config.name.clone(), + symbols: config.symbols.clone(), + window: *window, + zscore_entry: *zscore_entry, + zscore_exit: *zscore_exit, + min_half_life: *min_half_life, + max_half_life: *max_half_life, + adf_lags: *adf_lags, + recompute_interval: *recompute_interval, + max_buffer: *max_buffer, + startup_lookback: startup_lookback.clone(), + startup_bar_limit: *startup_bar_limit, + }, + )), StrategyParams::Custom { .. } => { todo!("Need to implement") } diff --git a/streamlit/pages/overview.py b/streamlit/pages/overview.py index c63bfdf..3fc873f 100644 --- a/streamlit/pages/overview.py +++ b/streamlit/pages/overview.py @@ -36,7 +36,7 @@ st.subheader("Equity Curve") history = get_account_history(account_id, 500) if not history.empty: - history['timestamp'] = pd.to_datetime(history['timestamp'], format='ISO8601') + history['timestamp'] = pd.to_datetime(history['timestamp'], format='ISO8601', utc=True) history = history.sort_values('timestamp') fig = go.Figure() diff --git a/streamlit/strategies/mean_reversion.py b/streamlit/strategies/mean_reversion.py index bf93ca4..f0df3ce 100644 --- a/streamlit/strategies/mean_reversion.py +++ b/streamlit/strategies/mean_reversion.py @@ -10,16 +10,28 @@ def strategy_type(self) -> str: def default_params(self) -> dict: return { - "window": 20, - "std_devs": 2.0, + "window": 60, + "zscore_entry": 2.0, + "zscore_exit": 0.5, + "min_half_life": 2.0, + "max_half_life": 50.0, + "adf_lags": 1, + "recompute_interval": 5, + "max_buffer": 500, "startup_lookback": "1Hour", - "startup_bar_limit": 50, + "startup_bar_limit": 500, } def param_descriptions(self) -> dict: return { - "window": "Rolling window size for computing the moving average.", - "std_devs": "Number of standard deviations from the mean to trigger a signal.", + "window": "Rolling window size for z-score and OU parameter estimation.", + "zscore_entry": "Z-score threshold to enter a trade. Higher = fewer but stronger signals.", + "zscore_exit": "Z-score threshold to exit. When price reverts past equilibrium, close the position.", + "min_half_life": "Minimum mean-reversion half-life in bars. Below this, signal is noise.", + "max_half_life": "Maximum mean-reversion half-life in bars. Above this, capital is tied up too long.", + "adf_lags": "Number of lags for the Augmented Dickey-Fuller stationarity test.", + "recompute_interval": "Recompute GPU signals every N bars. Lower = more responsive, higher = less GPU usage.", + "max_buffer": "Maximum price history to keep per symbol. Also controls startup bar fetch.", "startup_lookback": "Bar timeframe used during historical warmup before live trading begins.", "startup_bar_limit": "Number of historical bars to fetch during warmup.", } @@ -29,24 +41,49 @@ def render_params(self, st) -> dict: col1, col2 = st.columns(2) with col1: - window = st.number_input("Window Size", min_value=5, value=20) + window = st.number_input("Window Size", min_value=5, value=60) st.caption(descs["window"]) + zscore_entry = st.number_input("Z-Score Entry", min_value=0.5, value=2.0, step=0.1, format="%.1f") + st.caption(descs["zscore_entry"]) + + min_half_life = st.number_input("Min Half-Life", min_value=1.0, value=2.0, step=1.0, format="%.1f") + st.caption(descs["min_half_life"]) + + adf_lags = st.number_input("ADF Lags", min_value=1, value=1) + st.caption(descs["adf_lags"]) + startup_lookback = st.selectbox( "Startup Timeframe", ["1Min", "5Min", "15Min", "30Min", "1Hour", "1Day"], index=4 ) st.caption(descs["startup_lookback"]) with col2: - std_devs = st.number_input("Std Deviations", min_value=0.5, value=2.0, step=0.1, format="%.1f") - st.caption(descs["std_devs"]) + zscore_exit = st.number_input("Z-Score Exit", min_value=0.0, value=0.5, step=0.1, format="%.1f") + st.caption(descs["zscore_exit"]) + + max_half_life = st.number_input("Max Half-Life", min_value=5.0, value=50.0, step=5.0, format="%.1f") + st.caption(descs["max_half_life"]) + + recompute_interval = st.number_input("Recompute Interval", min_value=1, value=5) + st.caption(descs["recompute_interval"]) + + max_buffer = st.number_input("Max Buffer", min_value=50, value=500, step=50) + st.caption(descs["max_buffer"]) - startup_bar_limit = st.number_input("Startup Bar Limit", min_value=10, value=50) + startup_bar_limit = st.number_input("Startup Bar Limit", min_value=10, value=500) st.caption(descs["startup_bar_limit"]) return { "window": window, - "std_devs": std_devs, + "zscore_entry": zscore_entry, + "zscore_exit": zscore_exit, + "min_half_life": min_half_life, + "max_half_life": max_half_life, + "adf_lags": adf_lags, + "recompute_interval": recompute_interval, + "max_buffer": max_buffer, "startup_lookback": startup_lookback, "startup_bar_limit": startup_bar_limit, } + \ No newline at end of file