diff --git a/Cargo.lock b/Cargo.lock index 3abd532..3b9cf62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,9 +57,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-lc-rs" -version = "1.16.0" +version = "1.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9a7b350e3bb1767102698302bc37256cbd48422809984b98d292c40e2579aa9" +checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf" dependencies = [ "aws-lc-sys", "zeroize", @@ -67,9 +67,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.37.1" +version = "0.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b092fe214090261288111db7a2b2c2118e5a7f30dc2569f1732c4069a6840549" +checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e" dependencies = [ "cc", "cmake", @@ -157,9 +157,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.43" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ "iana-time-zone", "js-sys", @@ -286,6 +286,7 @@ dependencies = [ "dotenv", "reqwest", "serde", + "serde_json", "sqlx", "tempfile", "thiserror 2.0.18", @@ -556,20 +557,20 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "r-efi", + "r-efi 5.3.0", "wasip2", "wasm-bindgen", ] [[package]] name = "getrandom" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" dependencies = [ "cfg-if", "libc", - "r-efi", + "r-efi 6.0.0", "wasip2", "wasip3", ] @@ -906,9 +907,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.11.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "iri-string" @@ -960,9 +961,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.87" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f0862381daaec758576dcc22eb7bbf4d7efd67328553f3b45a412a51a3fb21" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" dependencies = [ "once_cell", "wasm-bindgen", @@ -985,9 +986,9 @@ checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" [[package]] name = "libc" -version = "0.2.182" +version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" [[package]] name = "libm" @@ -997,13 +998,14 @@ checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libredox" -version = "0.1.12" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" +checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" dependencies = [ "bitflags", "libc", - "redox_syscall 0.7.1", + "plain", + "redox_syscall 0.7.3", ] [[package]] @@ -1019,9 +1021,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" [[package]] name = "litemap" @@ -1187,9 +1189,9 @@ checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "pin-project-lite" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" [[package]] name = "pin-utils" @@ -1224,6 +1226,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "potential_utf" version = "0.1.4" @@ -1319,9 +1327,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.44" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] @@ -1332,6 +1340,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + [[package]] name = "rand" version = "0.8.5" @@ -1402,9 +1416,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.7.1" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35985aa610addc02e24fc232012c86fd11f14111180f902b67e2d5331f8ebf2b" +checksum = "6ce70a74e890531977d37e532c34d45e9055d2409ed08ddba14529471ed0be16" dependencies = [ "bitflags", ] @@ -1491,9 +1505,9 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustix" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" dependencies = [ "bitflags", "errno", @@ -1504,9 +1518,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.36" +version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ "aws-lc-rs", "once_cell", @@ -1771,12 +1785,12 @@ dependencies = [ [[package]] name = "socket2" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -2071,12 +2085,12 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.25.0" +version = "3.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" +checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" dependencies = [ "fastrand", - "getrandom 0.4.1", + "getrandom 0.4.2", "once_cell", "rustix", "windows-sys 0.61.2", @@ -2149,9 +2163,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.49.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" dependencies = [ "bytes", "libc", @@ -2166,9 +2180,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.6.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" dependencies = [ "proc-macro2", "quote", @@ -2407,11 +2421,11 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.21.0" +version = "1.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b672338555252d43fd2240c714dc444b8c6fb0a5c5335e65a07bba7742735ddb" +checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37" dependencies = [ - "getrandom 0.4.1", + "getrandom 0.4.2", "js-sys", "wasm-bindgen", ] @@ -2479,9 +2493,9 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.110" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1de241cdc66a9d91bd84f097039eb140cdc6eec47e0cdbaf9d932a1dd6c35866" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" dependencies = [ "cfg-if", "once_cell", @@ -2492,9 +2506,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.60" +version = "0.4.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a42e96ea38f49b191e08a1bab66c7ffdba24b06f9995b39a9dd60222e5b6f1da" +checksum = "e9c5522b3a28661442748e09d40924dfb9ca614b21c00d3fd135720e48b67db8" dependencies = [ "cfg-if", "futures-util", @@ -2506,9 +2520,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.110" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e12fdf6649048f2e3de6d7d5ff3ced779cdedee0e0baffd7dff5cdfa3abc8a52" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2516,9 +2530,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.110" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e63d1795c565ac3462334c1e396fd46dbf481c40f51f5072c310717bc4fb309" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" dependencies = [ "bumpalo", "proc-macro2", @@ -2529,9 +2543,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.110" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9f9cdac23a5ce71f6bf9f8824898a501e511892791ea2a0c6b8568c68b9cb53" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" dependencies = [ "unicode-ident", ] @@ -2572,9 +2586,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.87" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2c7c5718134e770ee62af3b6b4a84518ec10101aad610c024b64d6ff29bb1ff" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" dependencies = [ "js-sys", "wasm-bindgen", @@ -2978,9 +2992,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.14" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" [[package]] name = "wit-bindgen" @@ -3101,18 +3115,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.39" +version = "0.8.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" +checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.39" +version = "0.8.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" +checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index ccf7e50..de8bfe6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,8 @@ chrono = { version = "0.4.43", features = ["clock"] } dotenv = "0.15.0" reqwest = { version = "0.13.2", features = ["json"] } serde = { version = "1.0.228", features = ["derive"] } -sqlx = { version = "0.8.6", features = ["runtime-tokio", "sqlite", "chrono", "uuid"] } +serde_json = "1.0.149" +sqlx = { version = "0.8.6", features = ["runtime-tokio", "sqlite", "chrono", "uuid", "derive"] } thiserror = "2.0.18" tokio = { version = "1.49.0", features = ["full"] } toml = "0.9.11" diff --git a/src/broker/alpaca.rs b/src/broker/alpaca.rs index 83e1229..9cd1166 100644 --- a/src/broker/alpaca.rs +++ b/src/broker/alpaca.rs @@ -15,8 +15,9 @@ struct AlpacaApiAccount { } impl AlpacaApiAccount { - fn into_account(self) -> Account { + fn into_account(self, id: &str) -> Account { Account { + id: id.to_string(), equity: self.equity.parse().unwrap_or(0.0), cash: self.cash.parse().unwrap_or(0.0), buying_power: self.buying_power.parse().unwrap_or(0.0), @@ -88,6 +89,7 @@ impl AlpacaApiOrderRequest { } pub struct AlpacaApiBroker { + account_id: String, client: Client, base_url: String, api_key: String, @@ -95,8 +97,15 @@ pub struct AlpacaApiBroker { } impl AlpacaApiBroker { - pub fn new(api_endpoint: &str, api_key: &str, api_secret: &str, _paper: bool) -> Self { + pub fn new( + account_id: &str, + api_endpoint: &str, + api_key: &str, + api_secret: &str, + _paper: bool, + ) -> Self { Self { + account_id: account_id.to_string(), client: Client::new(), base_url: { api_endpoint }.to_string(), api_key: api_key.to_string(), @@ -129,7 +138,7 @@ impl AlpacaApiBroker { .await .map_err(|e| BrokerError::Fail(e.to_string()))?; - Ok(api_account.into_account()) + Ok(api_account.into_account(&self.account_id)) } pub async fn fetch_positions(&self) -> Result, BrokerError> { diff --git a/src/broker/paper.rs b/src/broker/paper.rs index 3bf1bcb..ff1e87f 100644 --- a/src/broker/paper.rs +++ b/src/broker/paper.rs @@ -15,9 +15,10 @@ pub struct PaperBroker { } impl PaperBroker { - pub fn new(starting_cash: f64) -> Self { + pub fn new(id: &str, starting_cash: f64) -> Self { Self { account: Mutex::new(Account { + id: id.to_string(), equity: starting_cash, cash: starting_cash, buying_power: starting_cash, diff --git a/src/broker/paper_alpaca.rs b/src/broker/paper_alpaca.rs index d943efe..84eaaa0 100644 --- a/src/broker/paper_alpaca.rs +++ b/src/broker/paper_alpaca.rs @@ -7,13 +7,18 @@ use crate::{ use async_trait::async_trait; pub struct PaperAlpacaBroker { + account_id: String, storage: Storage, data: MarketData, } impl PaperAlpacaBroker { - pub fn new(storage: Storage, data: MarketData) -> Self { - Self { storage, data } + pub fn new(account_id: &str, storage: Storage, data: MarketData) -> Self { + Self { + account_id: account_id.to_string(), + storage, + data, + } } } @@ -21,27 +26,26 @@ impl PaperAlpacaBroker { impl Broker for PaperAlpacaBroker { async fn get_account(&self) -> Result { self.storage - .get_account() + .get_account(&self.account_id) .await .map_err(|e| BrokerError::Fail(e.to_string())) } async fn get_positions(&self) -> Result, BrokerError> { self.storage - .get_positions() + .get_positions(&self.account_id) .await .map_err(|e| BrokerError::Fail(e.to_string())) } async fn get_position(&self, symbol: &str) -> Result, BrokerError> { self.storage - .get_position(symbol) + .get_position(&self.account_id, symbol) .await .map_err(|e| BrokerError::Fail(e.to_string())) } async fn submit_order(&self, order: &Order) -> Result { - // Fetch current price let bar = self .data .get_latest_bar(&order.symbol) @@ -55,13 +59,13 @@ impl Broker for PaperAlpacaBroker { match order.side { Side::Buy => { self.storage - .deduct_cash(order_value) + .deduct_cash(&self.account_id, order_value) .await .map_err(|_| BrokerError::InsufficientFunds)?; let existing = self .storage - .get_position(&order.symbol) + .get_position(&self.account_id, &order.symbol) .await .map_err(|e| BrokerError::Fail(e.to_string()))?; @@ -70,12 +74,19 @@ impl Broker for PaperAlpacaBroker { let total_cost = (pos.quantity * pos.avg_entry_price) + order_value; let new_avg = total_cost / total_qty; self.storage - .upsert_position(&order.symbol, total_qty, new_avg, current_price) + .upsert_position( + &self.account_id, + &order.symbol, + total_qty, + new_avg, + current_price, + ) .await .map_err(|e| BrokerError::Fail(e.to_string()))?; } else { self.storage .upsert_position( + &self.account_id, &order.symbol, order.quantity, current_price, @@ -88,7 +99,7 @@ impl Broker for PaperAlpacaBroker { Side::Sell => { let existing = self .storage - .get_position(&order.symbol) + .get_position(&self.account_id, &order.symbol) .await .map_err(|e| BrokerError::Fail(e.to_string()))? .ok_or_else(|| BrokerError::InvalidOrder("No position".to_string()))?; @@ -98,19 +109,20 @@ impl Broker for PaperAlpacaBroker { } self.storage - .add_cash(order_value) + .add_cash(&self.account_id, order_value) .await .map_err(|e| BrokerError::Fail(e.to_string()))?; let new_qty = existing.quantity - order.quantity; if new_qty <= 0.0 { self.storage - .delete_position(&order.symbol) + .delete_position(&self.account_id, &order.symbol) .await .map_err(|e| BrokerError::Fail(e.to_string()))?; } else { self.storage .upsert_position( + &self.account_id, &order.symbol, new_qty, existing.avg_entry_price, @@ -128,12 +140,13 @@ impl Broker for PaperAlpacaBroker { }; self.storage .insert_trade( + &self.account_id, &order.symbol, side_str, order.quantity, current_price, Some(&order_id), - None, // No strategy name in broker - that's engine's concern + None, ) .await .map_err(|e| BrokerError::Fail(e.to_string()))?; @@ -150,22 +163,11 @@ impl Broker for PaperAlpacaBroker { for (symbol, bar) in &bars { self.storage - .update_position_price(symbol, bar.close) + .update_position_price(&self.account_id, symbol, bar.close) .await .map_err(|e| BrokerError::Fail(e.to_string()))?; } - let account = self - .storage - .get_account() - .await - .map_err(|e| BrokerError::Fail(e.to_string()))?; - - self.storage - .insert_account_snapshot(account.equity, account.cash, account.buying_power) - .await - .map_err(|e| BrokerError::Fail(e.to_string()))?; - Ok(()) } } diff --git a/src/config/mod.rs b/src/config/mod.rs index a117954..ce55b9a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -6,7 +6,6 @@ use serde::Deserialize; pub struct Config { pub broker: BrokerConfig, pub data: DataConfig, - pub strategies: Vec, pub risk: RiskConfig, pub logging: LoggingConfig, pub storage: StorageConfig, diff --git a/src/data/mod.rs b/src/data/mod.rs index 0f1c689..1c96a5d 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -64,6 +64,7 @@ struct ApiMultiBarsHistoryResponse { bars: Option>>, } +#[derive(Debug, Clone)] pub struct MarketData { client: Client, data_endpoint: String, diff --git a/src/engine/runner.rs b/src/engine/runner.rs index b7c3457..22d0cf4 100644 --- a/src/engine/runner.rs +++ b/src/engine/runner.rs @@ -9,7 +9,7 @@ use crate::{ logging::Logger, risk::RiskManager, storage::Storage, - strategy::{Strategy, StrategyParams, StrategySettings, create_strategy}, + strategy::{Strategy, StrategySettings, create_strategy}, types::{Side, Signal}, }; @@ -19,81 +19,198 @@ pub struct SignalMessage { pub signal: Signal, } -/// Shared data accessible by any component pub struct SharedState { pub storage: Storage, - pub broker: Arc, pub data: MarketData, pub logger: Arc, } -/// Main trading engine +pub struct AccountRunner { + pub broker: Arc, + pub strategies: Vec>, +} + pub struct Engine { state: Arc, - strategies: Vec>, + accounts: HashMap, risk_manager: RiskManager, _signal_tx: mpsc::Sender, _signal_rx: mpsc::Receiver, - strategy_configs: Vec, } impl Engine { pub async fn new( storage: Storage, - broker: Arc, data: MarketData, logger: Arc, risk_config: RiskConfig, - strategy_configs: Vec, ) -> Self { let (signal_tx, signal_rx) = mpsc::channel(100); let state = Arc::new(SharedState { storage, - broker, data, logger, }); Self { state, - strategies: Vec::new(), + accounts: HashMap::new(), risk_manager: RiskManager::new(risk_config), _signal_tx: signal_tx, _signal_rx: signal_rx, - strategy_configs, } } - /// Register a strategy - pub fn add_strategy(&mut self, config: StrategySettings) { - let strategy = create_strategy(&config); + async fn reload_from_db(&mut self) { + for (account_id, runner) in &mut self.accounts { + let assignments = match self + .state + .storage + .get_strategies_for_account(account_id) + .await + { + Ok(a) => a, + Err(e) => { + self.state + .logger + .error(&format!("[{}] Failed to reload: {}", account_id, e)); + continue; + } + }; + + let db_fingerprint: Vec<(String, String, String)> = assignments + .iter() + .map(|(s, a)| { + ( + s.name.clone(), + a.symbols_json.clone(), + s.params_json.clone(), + ) + }) + .collect(); + + let current_fingerprint: Vec<(String, String, String)> = runner + .strategies + .iter() + .map(|s| { + let symbols = serde_json::to_string(s.symbols()).unwrap_or_default(); + let params = serde_json::to_string(&s.params()).unwrap_or_default(); + (s.name().to_string(), symbols, params) + }) + .collect(); + + if db_fingerprint == current_fingerprint { + continue; + } + + self.state + .logger + .info(&format!("[{}] Strategies changed, reloading", account_id)); + + let mut new_strategies: Vec> = Vec::new(); + + for (db_strat, assignment) in assignments { + let params = match db_strat.params() { + Ok(p) => p, + Err(e) => { + self.state.logger.error(&format!( + "[{}] Bad params for '{}': {}", + account_id, db_strat.name, e + )); + continue; + } + }; + + let symbols: Vec = + serde_json::from_str(&assignment.symbols_json).unwrap_or_default(); + + let settings = StrategySettings { + name: db_strat.name.clone(), + symbols, + params, + }; + + new_strategies.push(create_strategy(&settings)); + } + + runner.strategies = new_strategies; + } + } + + pub async fn load_strategies( + &mut self, + account_id: &str, + broker: Arc, + ) -> Result<(), String> { + let assignments = self + .state + .storage + .get_strategies_for_account(account_id) + .await + .map_err(|e| format!("Failed to load strategies: {}", e))?; + + let mut strategies: Vec> = Vec::new(); + + for (db_strat, assignment) in assignments { + let params = db_strat + .params() + .map_err(|e| format!("Bad params for '{}': {}", db_strat.name, e))?; + + let symbols: Vec = + serde_json::from_str(&assignment.symbols_json).unwrap_or_default(); + + let settings = StrategySettings { + name: db_strat.name.clone(), + symbols, + params, + }; + + let strategy = create_strategy(&settings); + self.state.logger.info(&format!( + "[{}] Loaded strategy: {} for {:?}", + account_id, + strategy.name(), + strategy.symbols(), + )); + strategies.push(strategy); + } + self.state.logger.info(&format!( - "Registered strategy: {} for {:?}", - strategy.name(), - strategy.symbols() + "[{}] {} strategies loaded", + account_id, + strategies.len() )); - self.strategies.push(strategy); + self.accounts + .insert(account_id.to_string(), AccountRunner { broker, strategies }); + Ok(()) } - /// Main loop - pub async fn run(&mut self, poll_interval_secs: u64) { - self.state.logger.info("Engine starting..."); - - let mut all_symbols: Vec = Vec::new(); - for strategy in &self.strategies { - for symbol in strategy.symbols() { - if !all_symbols.contains(symbol) { - all_symbols.push(symbol.clone()); + fn all_symbols(&self) -> Vec { + let mut symbols = Vec::new(); + for runner in self.accounts.values() { + for strategy in &runner.strategies { + for symbol in strategy.symbols() { + if !symbols.contains(symbol) { + symbols.push(symbol.clone()); + } } } } + symbols + } - self.state - .logger - .info(&format!("Tracking symbols: {:?}", all_symbols)); + pub async fn run(&mut self, poll_interval_secs: u64) { + self.state.logger.info("Engine starting..."); + + let all_symbols = self.all_symbols(); + self.state.logger.info(&format!( + "Tracking {} symbols across {} accounts: {:?}", + all_symbols.len(), + self.accounts.len(), + all_symbols + )); - // Prefetch historical data self.fetch_historical_data(&all_symbols).await; let mut ticker = interval(Duration::from_secs(poll_interval_secs)); @@ -101,9 +218,14 @@ impl Engine { loop { self.state.logger.info("=== Tick starting ==="); - self.state.broker.update_prices(&all_symbols).await.ok(); + self.reload_from_db().await; + + let all_symbols = self.all_symbols(); + + for runner in self.accounts.values() { + runner.broker.update_prices(&all_symbols).await.ok(); + } - // Fetch latest data for all symbols self.state.logger.info("Fetching latest bars..."); let bars = match self.state.data.get_latest_bars(&all_symbols).await { Ok(b) => { @@ -122,49 +244,88 @@ impl Engine { } }; - // Collect signals first - let mut signals: Vec<(String, Signal, f64)> = Vec::new(); - - for strategy in &mut self.strategies { - for symbol in strategy.symbols().to_vec() { - if let Some(bar) = bars.get(&symbol) - && let Some(signal) = strategy.on_bar(bar) - { - match &signal { - Signal::Buy { .. } | Signal::Sell { .. } => { - signals.push((strategy.name().to_string(), signal, bar.close)); + let mut all_signals: Vec<(String, String, Signal, f64)> = Vec::new(); + + for (account_id, runner) in &mut self.accounts { + for strategy in runner.strategies.iter_mut() { + for symbol in strategy.symbols().to_vec() { + if let Some(bar) = bars.get(&symbol) + && let Some(signal) = strategy.on_bar(bar) + { + match &signal { + Signal::Buy { .. } | Signal::Sell { .. } => { + all_signals.push(( + account_id.clone(), + strategy.name().to_string(), + signal, + bar.close, + )); + } + Signal::Hold => {} } - Signal::Hold => {} } } } } - self.state - .logger - .info(&format!("Generated {} signals", signals.len())); + self.state.logger.info(&format!( + "Generated {} signals across {} accounts", + all_signals.len(), + self.accounts.len(), + )); - // Now process signals - for (strategy_name, signal, price) in signals { - self.process_signal(&strategy_name, &signal, price).await; + for (account_id, strategy_name, signal, price) in all_signals { + self.process_signal(&account_id, &strategy_name, &signal, price) + .await; } - // Save account snapshot - match self.state.broker.get_account().await { - Ok(account) => { - let _ = self - .state - .storage - .insert_account_snapshot(account.equity, account.cash, account.buying_power) - .await; - self.state - .logger - .info(&format!("Account: ${:.2}", account.equity)); + for (account_id, runner) in &self.accounts { + match runner.broker.get_account().await { + Ok(account) => { + let _ = self + .state + .storage + .insert_account_snapshot( + account_id, + account.equity, + account.cash, + account.buying_power, + ) + .await; + self.state + .logger + .info(&format!("[{}] Account: ${:.2}", account_id, account.equity)); + } + Err(e) => { + self.state + .logger + .error(&format!("[{}] Failed to fetch account: {}", account_id, e)); + } } - Err(e) => { - self.state - .logger - .error(&format!("Failed to fetch account: {}", e)); + } + + for (account_id, runner) in &self.accounts { + match runner.broker.get_positions().await { + Ok(positions) => { + for pos in &positions { + let _ = self + .state + .storage + .upsert_position( + account_id, + &pos.symbol, + pos.quantity, + pos.avg_entry_price, + pos.current_price, + ) + .await; + } + } + Err(e) => { + self.state + .logger + .error(&format!("[{}] Failed to sync positions: {}", account_id, e)); + } } } @@ -176,25 +337,34 @@ impl Engine { } } - async fn process_signal(&mut self, strategy_name: &str, signal: &Signal, current_price: f64) { - // Log signal + async fn process_signal( + &mut self, + account_id: &str, + strategy_name: &str, + signal: &Signal, + current_price: f64, + ) { + let runner = match self.accounts.get(account_id) { + Some(r) => r, + None => return, + }; + match signal { Signal::Buy { symbol, strength } => { self.state.logger.info(&format!( - "[{}] BUY signal for {} (strength: {:.2})", - strategy_name, symbol, strength + "[{}][{}] BUY {} (strength: {:.2})", + account_id, strategy_name, symbol, strength )); } Signal::Sell { symbol, strength } => { self.state.logger.info(&format!( - "[{}] SELL signal for {} (strength: {:.2})", - strategy_name, symbol, strength + "[{}][{}] SELL {} (strength: {:.2})", + account_id, strategy_name, symbol, strength )); } Signal::Hold => return, } - // Save signal to DB let (symbol, signal_type, strength) = match signal { Signal::Buy { symbol, strength } => (symbol.clone(), "buy", *strength), Signal::Sell { symbol, strength } => (symbol.clone(), "sell", *strength), @@ -204,48 +374,45 @@ impl Engine { let _ = self .state .storage - .insert_signal(&symbol, signal_type, strength, strategy_name) + .insert_signal(account_id, &symbol, signal_type, strength, strategy_name) .await; - // Get account and positions for risk check - let account = match self.state.broker.get_account().await { + let account = match runner.broker.get_account().await { Ok(a) => a, Err(e) => { self.state .logger - .error(&format!("Failed to fetch account: {}", e)); + .error(&format!("[{}] Failed to fetch account: {}", account_id, e)); return; } }; - let positions = match self.state.broker.get_positions().await { + let positions = match runner.broker.get_positions().await { Ok(p) => p, Err(e) => { - self.state - .logger - .error(&format!("Failed to fetch positions: {}", e)); + self.state.logger.error(&format!( + "[{}] Failed to fetch positions: {}", + account_id, e + )); return; } }; - // Run through risk manager if let Some(order) = self.risk_manager .evaluate_signal(signal, &account, &positions, current_price) { self.state.logger.info(&format!( - "Executing order: {:?} {} x {}", - order.side, order.symbol, order.quantity + "[{}] Executing: {:?} {} x {}", + account_id, order.side, order.symbol, order.quantity )); - // Execute order - match self.state.broker.submit_order(&order).await { + match runner.broker.submit_order(&order).await { Ok(order_id) => { self.state .logger - .info(&format!("Order filled: {}", order_id)); + .info(&format!("[{}] Order filled: {}", account_id, order_id)); - // Save trade let side = match order.side { Side::Buy => "buy", Side::Sell => "sell", @@ -255,6 +422,7 @@ impl Engine { .state .storage .insert_trade( + account_id, &order.symbol, side, order.quantity, @@ -265,7 +433,9 @@ impl Engine { .await; } Err(e) => { - self.state.logger.error(&format!("Order failed: {}", e)); + self.state + .logger + .error(&format!("[{}] Order failed: {}", account_id, e)); } } } @@ -276,31 +446,25 @@ impl Engine { .logger .info("Fetching historical data for all symbols..."); - // Group symbols by their timeframe/limit config let mut by_config: HashMap<(String, u32), Vec> = HashMap::new(); - for symbol in symbols { - let config = self - .strategy_configs - .iter() - .find(|c| c.symbols.contains(symbol)); - - let (timeframe, limit) = match config.map(|c| &c.params) { - Some(StrategyParams::Momentum { - startup_lookback, - startup_bar_limit, - .. - }) => (startup_lookback.clone(), *startup_bar_limit), - _ => ("1Hour".to_string(), 50), - }; - - by_config - .entry((timeframe, limit)) - .or_default() - .push(symbol.clone()); + for runner in self.accounts.values() { + for strategy in &runner.strategies { + let (timeframe, limit) = strategy + .startup_config() + .unwrap_or(("1Hour".to_string(), 50)); + + for symbol in strategy.symbols() { + if symbols.contains(symbol) { + let entry = by_config.entry((timeframe.clone(), limit)).or_default(); + if !entry.contains(symbol) { + entry.push(symbol.clone()); + } + } + } + } } - // Batch fetch for each config group for ((timeframe, limit), group_symbols) in by_config { match self .state @@ -317,10 +481,12 @@ impl Engine { timeframe )); - for bar in bars { - for strategy in &mut self.strategies { + for runner in self.accounts.values_mut() { + for strategy in runner.strategies.iter_mut() { if strategy.symbols().contains(symbol) { - let _ = strategy.on_bar(bar); + for bar in bars { + let _ = strategy.on_bar(bar); + } } } } diff --git a/src/main.rs b/src/main.rs index f8e042b..d11d37b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,6 @@ use cubert::data::MarketData; use cubert::engine::Engine; use cubert::logging::Logger; use cubert::storage::Storage; -use cubert::strategy::{StrategyParams, StrategySettings}; #[tokio::main] async fn main() { @@ -26,7 +25,6 @@ async fn main() { logger.info("=== Cubert Starting ==="); - // Engine storage (for logging signals, snapshots, etc.) let storage = match Storage::connect(&config.storage.db_url).await { Ok(s) => s, Err(e) => { @@ -40,96 +38,78 @@ async fn main() { std::process::exit(1); } - // Create broker based on config - let broker: Arc = if config.broker.paper { - logger.info("Using paper broker (simulated)"); - - // Broker storage (simulates broker API) - let broker_storage = match Storage::connect(&config.storage.broker_db_url).await { - Ok(s) => s, - Err(e) => { - logger.error(&format!("Broker database error: {}", e)); - std::process::exit(1); - } - }; + let data = MarketData::new( + &config.broker.data_endpoint, + &config.broker.api_key, + &config.broker.api_secret, + ); - if let Err(e) = broker_storage.migrate().await { - logger.error(&format!("Broker migration failed: {}", e)); + let account_ids = match storage.get_active_account_ids().await { + Ok(ids) => ids, + Err(e) => { + logger.error(&format!("Failed to load accounts: {}", e)); std::process::exit(1); } + }; - let starting_cash = 100_000.0; - if let Err(e) = broker_storage.init_account(starting_cash).await { - logger.error(&format!("Failed to init account: {}", e)); - std::process::exit(1); - } + logger.info(&format!("Found {} active accounts", account_ids.len())); + let mut engine = Engine::new(storage, data.clone(), logger.clone(), config.risk.clone()).await; + + for account_id in &account_ids { + let broker: Arc = if config.broker.paper { + let broker_db_url = format!("sqlite:broker_{}.db", account_id); + let broker_storage = match Storage::connect(&broker_db_url).await { + Ok(s) => s, + Err(e) => { + logger.error(&format!("[{}] Broker DB error: {}", account_id, e)); + continue; + } + }; + + if let Err(e) = broker_storage.migrate().await { + logger.error(&format!("[{}] Broker migration failed: {}", account_id, e)); + continue; + } - let data = MarketData::new( - &config.broker.data_endpoint, - &config.broker.api_key, - &config.broker.api_secret, - ); - - Arc::new(PaperAlpacaBroker::new(broker_storage, data)) - } else { - logger.info("Using Alpaca live broker"); - - Arc::new(AlpacaApiBroker::new( - &config.broker.api_endpoint, - &config.broker.api_key, - &config.broker.api_secret, - false, - )) - }; + if let Err(e) = broker_storage.init_account(account_id, 100_000.0).await { + logger.error(&format!("[{}] Failed to init account: {}", account_id, e)); + continue; + } - // Verify account - match broker.get_account().await { - Ok(account) => { - logger.info(&format!( - "Account: ${:.2} cash, ${:.2} equity", - account.cash, account.equity - )); - } - Err(e) => { - logger.error(&format!("Account error: {}", e)); - std::process::exit(1); - } - } + Arc::new(PaperAlpacaBroker::new( + account_id, + broker_storage, + data.clone(), + )) + } else { + Arc::new(AlpacaApiBroker::new( + account_id, + &config.broker.api_endpoint, + &config.broker.api_key, + &config.broker.api_secret, + false, + )) + }; - // Market data - let data = MarketData::new( - &config.broker.data_endpoint, - &config.broker.api_key, - &config.broker.api_secret, - ); + match broker.get_account().await { + Ok(account) => { + logger.info(&format!( + "[{}] Account: ${:.2} cash, ${:.2} equity", + account_id, account.cash, account.equity + )); + } + Err(e) => { + logger.error(&format!("[{}] Account error: {}", account_id, e)); + continue; + } + } - let strategies: Vec = config - .strategies - .iter() - .map(|s| StrategySettings { - name: s.name.clone(), - symbols: s.symbols.clone(), - params: StrategyParams::Momentum { - lookback_period: s.params.lookback_period, - threshold: s.params.threshold, - startup_lookback: s.params.startup_lookback.clone(), - startup_bar_limit: s.params.startup_bar_limit, - }, - }) - .collect(); - - let mut engine = Engine::new( - storage, - broker, - data, - logger.clone(), - config.risk.clone(), - strategies.clone(), - ) - .await; - - for strategy in strategies { - engine.add_strategy(strategy); + if let Err(e) = engine.load_strategies(account_id, broker).await { + logger.error(&format!( + "[{}] Failed to load strategies: {}", + account_id, e + )); + } } engine.run(60).await; diff --git a/src/storage/models.rs b/src/storage/models.rs index cc5cb16..437ee80 100644 --- a/src/storage/models.rs +++ b/src/storage/models.rs @@ -1,8 +1,14 @@ +use chrono::Utc; +use serde::{Deserialize, Serialize}; use sqlx::FromRow; +use uuid::Uuid; + +use crate::strategy::StrategyParams; #[derive(Debug, Clone, FromRow)] pub struct DbTrade { pub id: String, + pub account_id: String, pub symbol: String, pub side: String, pub quantity: f64, @@ -13,18 +19,20 @@ pub struct DbTrade { } #[derive(Debug, Clone, FromRow)] -pub struct DbPosition { +pub struct DbSignal { pub id: String, + pub account_id: String, pub symbol: String, - pub quantity: f64, - pub avg_entry_price: f64, - pub current_price: f64, - pub updated_at: String, + pub signal_type: String, + pub strength: f64, + pub strategy: String, + pub timestamp: String, } #[derive(Debug, Clone, FromRow)] pub struct DbAccountSnapshot { pub id: String, + pub account_id: String, pub equity: f64, pub cash: f64, pub buying_power: f64, @@ -32,13 +40,14 @@ pub struct DbAccountSnapshot { } #[derive(Debug, Clone, FromRow)] -pub struct DbSignal { +pub struct DbPosition { pub id: String, + pub account_id: String, pub symbol: String, - pub signal_type: String, - pub strength: f64, - pub strategy: String, - pub timestamp: String, + pub quantity: f64, + pub avg_entry_price: f64, + pub current_price: f64, + pub updated_at: String, } #[derive(Debug, Clone, FromRow)] @@ -48,3 +57,46 @@ pub struct DbAccount { pub created_at: String, pub updated_at: String, } + +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] +pub struct DbStrategy { + pub id: String, + pub name: String, + pub strategy_type: String, + pub params_json: String, + pub created_at: String, + pub updated_at: String, +} + +impl DbStrategy { + pub fn new(name: String, params: &StrategyParams) -> Self { + let now = Utc::now().to_rfc3339(); + let strategy_type = match params { + StrategyParams::Momentum { .. } => "momentum", + StrategyParams::MeanReversion { .. } => "mean_reversion", + _ => "custom", + }; + + Self { + id: Uuid::new_v4().to_string(), + name, + strategy_type: strategy_type.to_string(), + params_json: serde_json::to_string(params).unwrap_or_default(), + created_at: now.clone(), + updated_at: now, + } + } + + pub fn params(&self) -> Result { + serde_json::from_str(&self.params_json) + } +} + +#[derive(Debug, Clone, FromRow)] +pub struct DbAccountStrategy { + pub account_id: String, + pub strategy_id: String, + pub symbols_json: String, + pub enabled: bool, + pub assigned_at: String, +} diff --git a/src/storage/sqlite.rs b/src/storage/sqlite.rs index 99bc80c..395d417 100644 --- a/src/storage/sqlite.rs +++ b/src/storage/sqlite.rs @@ -3,7 +3,11 @@ use sqlx::{Row, SqlitePool, sqlite::SqlitePoolOptions}; use uuid::Uuid; use super::models::{DbAccountSnapshot, DbPosition, DbSignal, DbTrade}; -use crate::types::{Account, Position}; +use crate::{ + storage::{DbAccountStrategy, DbStrategy}, + strategy::StrategyParams, + types::{Account, Position}, +}; #[derive(Debug)] pub enum StorageError { @@ -24,15 +28,23 @@ impl std::fmt::Display for StorageError { impl std::error::Error for StorageError {} +#[derive(Clone)] pub struct Storage { pool: SqlitePool, } impl Storage { pub async fn connect(database_url: &str) -> Result { + let url = if database_url.starts_with("sqlite:") && !database_url.contains("mode=") { + let separator = if database_url.contains('?') { "&" } else { "?" }; + format!("{}{}mode=rwc", database_url, separator) + } else { + database_url.to_string() + }; + let pool = SqlitePoolOptions::new() .max_connections(5) - .connect(database_url) + .connect(&url) .await .map_err(|e| StorageError::Connection(e.to_string()))?; @@ -40,11 +52,11 @@ impl Storage { } pub async fn migrate(&self) -> Result<(), StorageError> { - // Trades table sqlx::query( r#" CREATE TABLE IF NOT EXISTS trades ( id TEXT PRIMARY KEY, + account_id TEXT NOT NULL, symbol TEXT NOT NULL, side TEXT NOT NULL, quantity REAL NOT NULL, @@ -59,16 +71,17 @@ impl Storage { .await .map_err(|e| StorageError::Query(e.to_string()))?; - // Positions table sqlx::query( r#" CREATE TABLE IF NOT EXISTS positions ( id TEXT PRIMARY KEY, - symbol TEXT NOT NULL UNIQUE, + account_id TEXT NOT NULL, + symbol TEXT NOT NULL, quantity REAL NOT NULL, avg_entry_price REAL NOT NULL, current_price REAL NOT NULL, - updated_at TEXT NOT NULL + updated_at TEXT NOT NULL, + UNIQUE(account_id, symbol) ) "#, ) @@ -76,11 +89,11 @@ impl Storage { .await .map_err(|e| StorageError::Query(e.to_string()))?; - // Account snapshots table (historical) sqlx::query( r#" CREATE TABLE IF NOT EXISTS account_snapshots ( id TEXT PRIMARY KEY, + account_id TEXT NOT NULL, equity REAL NOT NULL, cash REAL NOT NULL, buying_power REAL NOT NULL, @@ -92,11 +105,11 @@ impl Storage { .await .map_err(|e| StorageError::Query(e.to_string()))?; - // Signals table sqlx::query( r#" CREATE TABLE IF NOT EXISTS signals ( id TEXT PRIMARY KEY, + account_id TEXT NOT NULL, symbol TEXT NOT NULL, signal_type TEXT NOT NULL, strength REAL NOT NULL, @@ -109,7 +122,6 @@ impl Storage { .await .map_err(|e| StorageError::Query(e.to_string()))?; - // Account table (current state - single row) sqlx::query( r#" CREATE TABLE IF NOT EXISTS account ( @@ -124,7 +136,39 @@ impl Storage { .await .map_err(|e| StorageError::Query(e.to_string()))?; - // Indexes + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS strategies ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + strategy_type TEXT NOT NULL, + params_json TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + "#, + ) + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS account_strategies ( + account_id TEXT NOT NULL, + strategy_id TEXT NOT NULL, + symbols_json TEXT NOT NULL DEFAULT '[]', + enabled BOOLEAN DEFAULT TRUE, + assigned_at DATETIME DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (account_id, strategy_id), + FOREIGN KEY (strategy_id) REFERENCES strategies(id) + ) + "#, + ) + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_trades_timestamp ON trades(timestamp DESC)") .execute(&self.pool) .await @@ -135,29 +179,45 @@ impl Storage { .await .map_err(|e| StorageError::Query(e.to_string()))?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_trades_account ON trades(account_id)") + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + sqlx::query("CREATE INDEX IF NOT EXISTS idx_signals_account ON signals(account_id)") + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + sqlx::query( + "CREATE INDEX IF NOT EXISTS idx_snapshots_account ON account_snapshots(account_id)", + ) + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + Ok(()) } // ============ Account ============ - /// Initialize account with starting cash (only if not exists) - pub async fn init_account(&self, starting_cash: f64) -> Result<(), StorageError> { - let existing = sqlx::query("SELECT id FROM account LIMIT 1") + pub async fn init_account( + &self, + account_id: &str, + starting_cash: f64, + ) -> Result<(), StorageError> { + let existing = sqlx::query("SELECT id FROM account WHERE id = ?") + .bind(account_id) .fetch_optional(&self.pool) .await .map_err(|e| StorageError::Query(e.to_string()))?; if existing.is_none() { - let id = Uuid::new_v4().to_string(); let now = Utc::now().to_rfc3339(); - sqlx::query( - r#" - INSERT INTO account (id, cash, created_at, updated_at) - VALUES (?, ?, ?, ?) - "#, + "INSERT INTO account (id, cash, created_at, updated_at) VALUES (?, ?, ?, ?)", ) - .bind(&id) + .bind(account_id) .bind(starting_cash) .bind(&now) .bind(&now) @@ -169,71 +229,66 @@ impl Storage { Ok(()) } - /// Get current account state (cash + positions = equity) - pub async fn get_account(&self) -> Result { - let row = sqlx::query("SELECT cash FROM account LIMIT 1") + pub async fn get_account(&self, account_id: &str) -> Result { + let row = sqlx::query("SELECT id, cash FROM account WHERE id = ?") + .bind(account_id) .fetch_optional(&self.pool) .await .map_err(|e| StorageError::Query(e.to_string()))? .ok_or(StorageError::NotFound)?; let cash: f64 = row.get("cash"); - - // Calculate equity from positions - let positions = self.get_positions().await?; + let positions = self.get_positions(account_id).await?; let positions_value: f64 = positions.iter().map(|p| p.quantity * p.current_price).sum(); - let equity = cash + positions_value; Ok(Account { + id: account_id.to_string(), equity, cash, - buying_power: cash, // Simplified: buying power = cash + buying_power: cash, }) } - /// Update account cash - pub async fn update_cash(&self, new_cash: f64) -> Result<(), StorageError> { + pub async fn update_cash(&self, account_id: &str, new_cash: f64) -> Result<(), StorageError> { let now = Utc::now().to_rfc3339(); - - sqlx::query("UPDATE account SET cash = ?, updated_at = ?") + sqlx::query("UPDATE account SET cash = ?, updated_at = ? WHERE id = ?") .bind(new_cash) .bind(&now) + .bind(account_id) .execute(&self.pool) .await .map_err(|e| StorageError::Query(e.to_string()))?; - Ok(()) } - /// Deduct cash (for buying) - pub async fn deduct_cash(&self, amount: f64) -> Result { - let account = self.get_account().await?; + pub async fn deduct_cash(&self, account_id: &str, amount: f64) -> Result { + let account = self.get_account(account_id).await?; let new_cash = account.cash - amount; - if new_cash < 0.0 { return Err(StorageError::Query("Insufficient funds".to_string())); } - - self.update_cash(new_cash).await?; + self.update_cash(account_id, new_cash).await?; Ok(new_cash) } - /// Add cash (for selling) - pub async fn add_cash(&self, amount: f64) -> Result { - let account = self.get_account().await?; + pub async fn add_cash(&self, account_id: &str, amount: f64) -> Result { + let account = self.get_account(account_id).await?; let new_cash = account.cash + amount; - self.update_cash(new_cash).await?; + self.update_cash(account_id, new_cash).await?; Ok(new_cash) } // ============ Positions ============ - pub async fn get_positions(&self) -> Result, StorageError> { - let rows = sqlx::query_as::<_, DbPosition>("SELECT * FROM positions WHERE quantity != 0") - .fetch_all(&self.pool) - .await - .map_err(|e| StorageError::Query(e.to_string()))?; + pub async fn get_positions(&self, account_id: &str) -> Result, StorageError> { + let rows = sqlx::query_as::<_, DbPosition>( + "SELECT * FROM positions WHERE account_id = ? AND quantity != 0", + ) + .bind(account_id) + .fetch_all(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; Ok(rows .into_iter() @@ -246,10 +301,15 @@ impl Storage { .collect()) } - pub async fn get_position(&self, symbol: &str) -> Result, StorageError> { + pub async fn get_position( + &self, + account_id: &str, + symbol: &str, + ) -> Result, StorageError> { let row = sqlx::query_as::<_, DbPosition>( - "SELECT * FROM positions WHERE symbol = ? AND quantity != 0", + "SELECT * FROM positions WHERE account_id = ? AND symbol = ? AND quantity != 0", ) + .bind(account_id) .bind(symbol) .fetch_optional(&self.pool) .await @@ -265,6 +325,7 @@ impl Storage { pub async fn upsert_position( &self, + account_id: &str, symbol: &str, quantity: f64, avg_entry_price: f64, @@ -275,9 +336,9 @@ impl Storage { sqlx::query( r#" - INSERT INTO positions (id, symbol, quantity, avg_entry_price, current_price, updated_at) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(symbol) DO UPDATE SET + INSERT INTO positions (id, account_id, symbol, quantity, avg_entry_price, current_price, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(account_id, symbol) DO UPDATE SET quantity = excluded.quantity, avg_entry_price = excluded.avg_entry_price, current_price = excluded.current_price, @@ -285,6 +346,7 @@ impl Storage { "#, ) .bind(&id) + .bind(account_id) .bind(symbol) .bind(quantity) .bind(avg_entry_price) @@ -299,14 +361,16 @@ impl Storage { pub async fn update_position_price( &self, + account_id: &str, symbol: &str, current_price: f64, ) -> Result<(), StorageError> { let updated_at = Utc::now().to_rfc3339(); - sqlx::query("UPDATE positions SET current_price = ?, updated_at = ? WHERE symbol = ?") + sqlx::query("UPDATE positions SET current_price = ?, updated_at = ? WHERE account_id = ? AND symbol = ?") .bind(current_price) .bind(&updated_at) + .bind(account_id) .bind(symbol) .execute(&self.pool) .await @@ -315,8 +379,13 @@ impl Storage { Ok(()) } - pub async fn delete_position(&self, symbol: &str) -> Result<(), StorageError> { - sqlx::query("DELETE FROM positions WHERE symbol = ?") + pub async fn delete_position( + &self, + account_id: &str, + symbol: &str, + ) -> Result<(), StorageError> { + sqlx::query("DELETE FROM positions WHERE account_id = ? AND symbol = ?") + .bind(account_id) .bind(symbol) .execute(&self.pool) .await @@ -327,8 +396,10 @@ impl Storage { // ============ Trades ============ + #[allow(clippy::too_many_arguments)] pub async fn insert_trade( &self, + account_id: &str, symbol: &str, side: &str, quantity: f64, @@ -341,11 +412,12 @@ impl Storage { sqlx::query( r#" - INSERT INTO trades (id, symbol, side, quantity, price, timestamp, order_id, strategy) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO trades (id, account_id, symbol, side, quantity, price, timestamp, order_id, strategy) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) "#, ) .bind(&id) + .bind(account_id) .bind(symbol) .bind(side) .bind(quantity) @@ -368,6 +440,21 @@ impl Storage { .map_err(|e| StorageError::Query(e.to_string())) } + pub async fn get_trades_by_account( + &self, + account_id: &str, + limit: i32, + ) -> Result, StorageError> { + sqlx::query_as::<_, DbTrade>( + "SELECT * FROM trades WHERE account_id = ? ORDER BY timestamp DESC LIMIT ?", + ) + .bind(account_id) + .bind(limit) + .fetch_all(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string())) + } + pub async fn get_trades_by_symbol( &self, symbol: &str, @@ -387,6 +474,7 @@ impl Storage { pub async fn insert_account_snapshot( &self, + account_id: &str, equity: f64, cash: f64, buying_power: f64, @@ -396,11 +484,12 @@ impl Storage { sqlx::query( r#" - INSERT INTO account_snapshots (id, equity, cash, buying_power, timestamp) - VALUES (?, ?, ?, ?, ?) + INSERT INTO account_snapshots (id, account_id, equity, cash, buying_power, timestamp) + VALUES (?, ?, ?, ?, ?, ?) "#, ) .bind(&id) + .bind(account_id) .bind(equity) .bind(cash) .bind(buying_power) @@ -414,11 +503,13 @@ impl Storage { pub async fn get_account_history( &self, + account_id: &str, limit: i32, ) -> Result, StorageError> { sqlx::query_as::<_, DbAccountSnapshot>( - "SELECT * FROM account_snapshots ORDER BY timestamp DESC LIMIT ?", + "SELECT * FROM account_snapshots WHERE account_id = ? ORDER BY timestamp DESC LIMIT ?", ) + .bind(account_id) .bind(limit) .fetch_all(&self.pool) .await @@ -429,6 +520,7 @@ impl Storage { pub async fn insert_signal( &self, + account_id: &str, symbol: &str, signal_type: &str, strength: f64, @@ -439,11 +531,12 @@ impl Storage { sqlx::query( r#" - INSERT INTO signals (id, symbol, signal_type, strength, strategy, timestamp) - VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO signals (id, account_id, symbol, signal_type, strength, strategy, timestamp) + VALUES (?, ?, ?, ?, ?, ?, ?) "#, ) .bind(&id) + .bind(account_id) .bind(symbol) .bind(signal_type) .bind(strength) @@ -463,4 +556,222 @@ impl Storage { .await .map_err(|e| StorageError::Query(e.to_string())) } + + pub async fn get_signals_by_account( + &self, + account_id: &str, + limit: i32, + ) -> Result, StorageError> { + sqlx::query_as::<_, DbSignal>( + "SELECT * FROM signals WHERE account_id = ? ORDER BY timestamp DESC LIMIT ?", + ) + .bind(account_id) + .bind(limit) + .fetch_all(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string())) + } + + // ============ Strategies ============ + + pub async fn insert_strategy( + &self, + name: &str, + params: &StrategyParams, + ) -> Result { + let id = Uuid::new_v4().to_string(); + let timestamp = Utc::now().to_rfc3339(); + let params_json = + serde_json::to_string(params).map_err(|e| StorageError::Query(e.to_string()))?; + + let strategy_type = match params { + StrategyParams::Momentum { .. } => "Momentum", + StrategyParams::MeanReversion { .. } => "MeanReversion", + _ => "Custom", + }; + + sqlx::query( + "INSERT INTO strategies (id, name, strategy_type, params_json, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?)", + ) + .bind(&id) + .bind(name) + .bind(strategy_type) + .bind(params_json) + .bind(×tamp) + .bind(×tamp) + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + Ok(id) + } + + pub async fn get_strategies(&self) -> Result, StorageError> { + let strategies: Vec = sqlx::query_as("SELECT * FROM strategies") + .fetch_all(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + Ok(strategies) + } + + pub async fn get_strategy_by_id(&self, id: &str) -> Result { + sqlx::query_as::<_, DbStrategy>("SELECT * FROM strategies WHERE id = ?") + .bind(id) + .fetch_one(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string())) + } + + pub async fn update_strategy( + &self, + id: &str, + name: &str, + strategy_type: &str, + params_json: &str, + ) -> Result<(), StorageError> { + let now = Utc::now().to_rfc3339(); + + let result = sqlx::query( + "UPDATE strategies SET name = ?, strategy_type = ?, params_json = ?, updated_at = ? WHERE id = ?", + ) + .bind(name) + .bind(strategy_type) + .bind(params_json) + .bind(&now) + .bind(id) + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + if result.rows_affected() == 0 { + return Err(StorageError::NotFound); + } + Ok(()) + } + + pub async fn delete_strategy(&self, id: &str) -> Result<(), StorageError> { + sqlx::query("DELETE FROM account_strategies WHERE strategy_id = ?") + .bind(id) + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + sqlx::query("DELETE FROM strategies WHERE id = ?") + .bind(id) + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + Ok(()) + } + + // ============ Account Strategies ============ + + pub async fn assign_strategy_to_account( + &self, + account_id: &str, + strategy_id: &str, + symbols: &[String], + ) -> Result<(), StorageError> { + let symbols_json = serde_json::to_string(symbols).unwrap_or_default(); + + sqlx::query( + "INSERT OR REPLACE INTO account_strategies (account_id, strategy_id, symbols_json, enabled) + VALUES (?, ?, ?, TRUE)", + ) + .bind(account_id) + .bind(strategy_id) + .bind(&symbols_json) + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + Ok(()) + } + + pub async fn unassign_strategy_from_account( + &self, + account_id: &str, + strategy_id: &str, + ) -> Result<(), StorageError> { + sqlx::query("DELETE FROM account_strategies WHERE account_id = ? AND strategy_id = ?") + .bind(account_id) + .bind(strategy_id) + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + Ok(()) + } + + pub async fn get_strategies_for_account( + &self, + account_id: &str, + ) -> Result, StorageError> { + let assignments = sqlx::query_as::<_, DbAccountStrategy>( + "SELECT * FROM account_strategies WHERE account_id = ? AND enabled = TRUE", + ) + .bind(account_id) + .fetch_all(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + let mut results = Vec::new(); + for assignment in assignments { + let strategy = sqlx::query_as::<_, DbStrategy>("SELECT * FROM strategies WHERE id = ?") + .bind(&assignment.strategy_id) + .fetch_one(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + results.push((strategy, assignment)); + } + + Ok(results) + } + + pub async fn get_accounts_for_strategy( + &self, + strategy_id: &str, + ) -> Result, StorageError> { + sqlx::query_as::<_, DbAccountStrategy>( + "SELECT * FROM account_strategies WHERE strategy_id = ? AND enabled = TRUE", + ) + .bind(strategy_id) + .fetch_all(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string())) + } + + pub async fn toggle_strategy_enabled( + &self, + account_id: &str, + strategy_id: &str, + enabled: bool, + ) -> Result<(), StorageError> { + sqlx::query( + "UPDATE account_strategies SET enabled = ? WHERE account_id = ? AND strategy_id = ?", + ) + .bind(enabled) + .bind(account_id) + .bind(strategy_id) + .execute(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + Ok(()) + } + + pub async fn get_active_account_ids(&self) -> Result, StorageError> { + let rows = sqlx::query_scalar::<_, String>( + "SELECT DISTINCT account_id FROM account_strategies WHERE enabled = TRUE", + ) + .fetch_all(&self.pool) + .await + .map_err(|e| StorageError::Query(e.to_string()))?; + + Ok(rows) + } } diff --git a/src/strategy/mod.rs b/src/strategy/mod.rs index c11a52a..743cbb0 100644 --- a/src/strategy/mod.rs +++ b/src/strategy/mod.rs @@ -2,6 +2,7 @@ pub mod momentum; use crate::types::{Bar, Signal}; use async_trait::async_trait; +use serde::{Deserialize, Serialize}; #[derive(Debug, Clone)] pub struct StrategySettings { @@ -10,7 +11,8 @@ pub struct StrategySettings { pub params: StrategyParams, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] pub enum StrategyParams { Momentum { lookback_period: usize, @@ -33,6 +35,10 @@ pub trait Strategy: Send + Sync { fn symbols(&self) -> &[String]; fn on_bar(&mut self, bar: &Bar) -> Option; fn reset(&mut self); + fn startup_config(&self) -> Option<(String, u32)> { + None + } + fn params(&self) -> StrategyParams; } pub fn create_strategy(config: &StrategySettings) -> Box { @@ -40,12 +46,15 @@ pub fn create_strategy(config: &StrategySettings) -> Box { StrategyParams::Momentum { lookback_period, threshold, - .. + startup_lookback, + startup_bar_limit, } => Box::new(momentum::MomentumStrategy::new( &config.name, config.symbols.clone(), *lookback_period, *threshold, + startup_lookback.clone(), + *startup_bar_limit, )), StrategyParams::MeanReversion { window: _, diff --git a/src/strategy/momentum.rs b/src/strategy/momentum.rs index 34dc6f8..8bf0daa 100644 --- a/src/strategy/momentum.rs +++ b/src/strategy/momentum.rs @@ -1,39 +1,47 @@ +use std::collections::HashMap; + use crate::strategy::Bar; use crate::strategy::Strategy; +use crate::strategy::StrategyParams; use crate::types::Signal; -use std::collections::HashMap; pub struct MomentumStrategy { name: String, symbols: Vec, lookback_period: usize, threshold: f64, - history: HashMap>, + startup_lookback: String, + startup_bar_limit: u32, + price_history: HashMap>, } impl MomentumStrategy { - pub fn new(name: &str, symbols: Vec, lookback_period: usize, threshold: f64) -> Self { - let mut history = HashMap::new(); - for symbol in &symbols { - history.insert(symbol.clone(), Vec::with_capacity(lookback_period + 1)); - } - + pub fn new( + name: &str, + symbols: Vec, + lookback_period: usize, + threshold: f64, + startup_lookback: String, + startup_bar_limit: u32, + ) -> Self { Self { name: name.to_string(), symbols, lookback_period, threshold, - history, + startup_lookback, + startup_bar_limit, + price_history: HashMap::new(), } } - fn calculate_momentum(&self, prices: &[f64]) -> Option { - if prices.len() < self.lookback_period { + fn calculate_momentum(prices: &[f64], lookback_period: usize) -> Option { + if prices.len() < lookback_period { return None; } let current = *prices.last()?; - let past = prices[prices.len() - self.lookback_period]; + let past = prices[prices.len() - lookback_period]; if past == 0.0 { return None; @@ -51,18 +59,15 @@ impl Strategy for MomentumStrategy { &self.symbols } fn on_bar(&mut self, bar: &Bar) -> Option { - { - let history = self.history.get_mut(&bar.symbol)?; + let history = self.price_history.entry(bar.symbol.clone()).or_default(); - history.push(bar.close); + history.push(bar.close); - if history.len() > self.lookback_period + 10 { - history.drain(0..10); - } + if history.len() > self.lookback_period + 10 { + history.drain(0..10); } - let history = self.history.get(&bar.symbol)?; - let momentum = self.calculate_momentum(history)?; + let momentum = MomentumStrategy::calculate_momentum(history, self.lookback_period)?; if momentum > self.threshold { Some(Signal::Buy { @@ -80,8 +85,21 @@ impl Strategy for MomentumStrategy { } fn reset(&mut self) { - for history in self.history.values_mut() { + for history in self.price_history.values_mut() { history.clear(); } } + + fn startup_config(&self) -> Option<(String, u32)> { + Some((self.startup_lookback.clone(), self.startup_bar_limit)) + } + + fn params(&self) -> StrategyParams { + StrategyParams::Momentum { + lookback_period: self.lookback_period, + threshold: self.threshold, + startup_lookback: self.startup_lookback.clone(), + startup_bar_limit: self.startup_bar_limit, + } + } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 9e9aeb7..80b5faf 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -76,6 +76,7 @@ impl Position { #[derive(Debug, Clone)] pub struct Account { + pub id: String, pub equity: f64, pub cash: f64, pub buying_power: f64, diff --git a/streamlit/app.py b/streamlit/app.py index 5edac59..cc582d1 100644 --- a/streamlit/app.py +++ b/streamlit/app.py @@ -4,135 +4,106 @@ import plotly.graph_objects as go from utils.db import ( get_account_stats, - get_positions, get_signals, - get_broker_trades, - get_engine_trades, + get_trades, get_account_history, get_strategy_performance, get_signal_stats, + get_active_account_ids, ) st.set_page_config(page_title="Cubert Dashboard", layout="wide") -st.title("🤖 Cubert Trading Dashboard") +st.title("Cubert Trading Dashboard") -# Auto-refresh -if st.button("🔄 Refresh"): +if st.button("Refresh"): st.rerun() +account_ids = get_active_account_ids() + +if not account_ids: + st.info("No active accounts. Go to Accounts to create one.") + st.stop() + +account_id = st.selectbox("Account", account_ids) + +stats = get_account_stats(account_id) + +if not stats: + st.info(f"No data for {account_id} yet. The engine needs to run at least one tick.") + st.stop() + # ============ Account Overview ============ -st.header("📊 Account Overview") +st.header("Account Overview") -stats = get_account_stats() col1, col2, col3, col4 = st.columns(4) - -with col1: - st.metric("Equity", f"${stats['equity']:,.2f}") -with col2: - st.metric("Cash", f"${stats['cash']:,.2f}") -with col3: - st.metric("Unrealized P&L", f"${stats['unrealized_pnl']:,.2f}") -with col4: - st.metric("Positions", stats['positions_count']) - -col5, col6 = st.columns(2) -with col5: - st.metric("Total Trades", stats['total_trades']) -with col6: - st.metric("Total Signals", stats['total_signals']) +col1.metric("Equity", f"${stats['equity']:,.2f}") +col2.metric("Cash", f"${stats['cash']:,.2f}") +col3.metric("Trades", stats['total_trades']) +col4.metric("Signals", stats['total_signals']) # ============ Equity Curve ============ -st.header("📈 Equity Curve") +st.header("Equity Curve") -history = get_account_history(1000) +history = get_account_history(account_id, 1000) if not history.empty: - history['timestamp'] = pd.to_datetime(history['timestamp']) - fig = px.line(history, x='timestamp', y='equity', title='Account Equity Over Time') + history['timestamp'] = pd.to_datetime(history['timestamp'], format='ISO8601') + fig = px.line(history, x='timestamp', y='equity') fig.update_layout(xaxis_title="Time", yaxis_title="Equity ($)") st.plotly_chart(fig, use_container_width=True) else: st.info("No account history yet") -# ============ Positions ============ -st.header("💼 Current Positions") - -positions = get_positions() -if not positions.empty: - positions['market_value'] = positions['quantity'] * positions['current_price'] - positions['pnl'] = (positions['current_price'] - positions['avg_entry_price']) * positions['quantity'] - positions['pnl_pct'] = ((positions['current_price'] / positions['avg_entry_price']) - 1) * 100 - - st.dataframe( - positions[['symbol', 'quantity', 'avg_entry_price', 'current_price', 'market_value', 'pnl', 'pnl_pct']], - use_container_width=True, - column_config={ - 'avg_entry_price': st.column_config.NumberColumn('Avg Entry', format='$%.2f'), - 'current_price': st.column_config.NumberColumn('Current', format='$%.2f'), - 'market_value': st.column_config.NumberColumn('Value', format='$%.2f'), - 'pnl': st.column_config.NumberColumn('P&L', format='$%.2f'), - 'pnl_pct': st.column_config.NumberColumn('P&L %', format='%.2f%%'), - } - ) -else: - st.info("No open positions") - # ============ Strategy Performance ============ -st.header("🎯 Strategy Performance") +st.header("Strategy Performance") col1, col2 = st.columns(2) with col1: st.subheader("Trade Count by Strategy") - perf = get_strategy_performance() + perf = get_strategy_performance(account_id) if not perf.empty: - fig = px.bar(perf, x='strategy', y='trade_count', title='Trades per Strategy') + fig = px.bar(perf, x='strategy', y='trade_count') st.plotly_chart(fig, use_container_width=True) else: st.info("No trades yet") with col2: st.subheader("Signal Stats") - signal_stats = get_signal_stats() + signal_stats = get_signal_stats(account_id) if not signal_stats.empty: st.dataframe(signal_stats, use_container_width=True) else: st.info("No signals yet") # ============ Recent Trades ============ -st.header("📜 Recent Trades") +st.header("Recent Trades") -tab1, tab2 = st.tabs(["Broker Trades (Executed)", "Engine Trades (With Strategy)"]) - -with tab1: - broker_trades = get_broker_trades(50) - if not broker_trades.empty: - broker_trades['value'] = broker_trades['quantity'] * broker_trades['price'] - st.dataframe(broker_trades, use_container_width=True) - else: - st.info("No trades yet") - -with tab2: - engine_trades = get_engine_trades(50) - if not engine_trades.empty: - engine_trades['value'] = engine_trades['quantity'] * engine_trades['price'] - st.dataframe(engine_trades, use_container_width=True) - else: - st.info("No trades yet") +trades = get_trades(account_id=account_id, limit=50) +if not trades.empty: + trades['value'] = trades['quantity'] * trades['price'] + st.dataframe( + trades[['timestamp', 'symbol', 'side', 'quantity', 'price', 'value', 'strategy']], + use_container_width=True, hide_index=True, + ) +else: + st.info("No trades yet") # ============ Recent Signals ============ -st.header("📡 Recent Signals") +st.header("Recent Signals") -signals = get_signals(50) +signals = get_signals(account_id=account_id, limit=50) if not signals.empty: - st.dataframe(signals, use_container_width=True) - - # Signal distribution + st.dataframe( + signals[['timestamp', 'symbol', 'signal_type', 'strength', 'strategy']], + use_container_width=True, hide_index=True, + ) + col1, col2 = st.columns(2) with col1: fig = px.pie(signals, names='signal_type', title='Signal Distribution') st.plotly_chart(fig, use_container_width=True) with col2: - fig = px.histogram(signals, x='strength', nbins=20, title='Signal Strength Distribution') + fig = px.histogram(signals, x='strength', nbins=20, title='Signal Strength') st.plotly_chart(fig, use_container_width=True) else: st.info("No signals yet") diff --git a/streamlit/brokers/__init__.py b/streamlit/brokers/__init__.py new file mode 100644 index 0000000..2f6e9f6 --- /dev/null +++ b/streamlit/brokers/__init__.py @@ -0,0 +1,15 @@ +from .paper_alpaca import PaperAlpacaDefinition +from .alpaca_live import AlpacaLiveDefinition + +_ALL_BROKERS = [ + PaperAlpacaDefinition(), + AlpacaLiveDefinition(), +] + +BROKER_REGISTRY = {b.broker_type(): b for b in _ALL_BROKERS} + +def get_broker_types(): + return list(BROKER_REGISTRY.keys()) + +def get_broker_definition(broker_type: str): + return BROKER_REGISTRY.get(broker_type) diff --git a/streamlit/brokers/alpaca_live.py b/streamlit/brokers/alpaca_live.py new file mode 100644 index 0000000..c29a01b --- /dev/null +++ b/streamlit/brokers/alpaca_live.py @@ -0,0 +1,16 @@ +from .base import BrokerDefinition + + +class AlpacaLiveDefinition(BrokerDefinition): + def name(self) -> str: + return "Alpaca Live" + + def broker_type(self) -> str: + return "live" + + def description(self) -> str: + return "Live trading via Alpaca API. Uses real money." + + def render_config(self, st) -> dict: + st.warning("This will use real money. Ensure API keys are configured in .env") + return {} diff --git a/streamlit/brokers/base.py b/streamlit/brokers/base.py new file mode 100644 index 0000000..e95b7e3 --- /dev/null +++ b/streamlit/brokers/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + + +class BrokerDefinition(ABC): + @abstractmethod + def name(self) -> str: + pass + + @abstractmethod + def broker_type(self) -> str: + pass + + @abstractmethod + def render_config(self, st) -> dict: + """Render Streamlit inputs and return config dict. Return None to skip.""" + pass + + @abstractmethod + def description(self) -> str: + pass diff --git a/streamlit/brokers/paper_alpaca.py b/streamlit/brokers/paper_alpaca.py new file mode 100644 index 0000000..9ccfd8e --- /dev/null +++ b/streamlit/brokers/paper_alpaca.py @@ -0,0 +1,18 @@ +from .base import BrokerDefinition + + +class PaperAlpacaDefinition(BrokerDefinition): + def name(self) -> str: + return "Paper (Alpaca Data)" + + def broker_type(self) -> str: + return "paper" + + def description(self) -> str: + return "Simulated trading with real Alpaca market data. No real money." + + def render_config(self, st) -> dict: + starting_cash = st.number_input( + "Starting Cash", min_value=1000.0, value=100000.0, step=10000.0, format="%.2f" + ) + return {"starting_cash": starting_cash} diff --git a/streamlit/pages/account.py b/streamlit/pages/account.py deleted file mode 100644 index 577719d..0000000 --- a/streamlit/pages/account.py +++ /dev/null @@ -1,84 +0,0 @@ -import streamlit as st -import sys -sys.path.append('..') -from utils.db import get_account_history -import plotly.graph_objects as go -import pandas as pd - -st.title("💰 Account History") - -account_df = get_account_history(limit=500) - -if not account_df.empty: - # Sort by timestamp - account_df = account_df.sort_values('timestamp') - account_df['timestamp'] = pd.to_datetime(account_df['timestamp']) - - # Current stats - latest = account_df.iloc[-1] - - col1, col2, col3 = st.columns(3) - - with col1: - st.metric("Current Equity", f"${latest['equity']:,.2f}") - - with col2: - st.metric("Cash", f"${latest['cash']:,.2f}") - - with col3: - st.metric("Buying Power", f"${latest['buying_power']:,.2f}") - - st.markdown("---") - - # Equity curve - st.subheader("Equity Curve") - - fig = go.Figure() - - fig.add_trace(go.Scatter( - x=account_df['timestamp'], - y=account_df['equity'], - mode='lines', - name='Equity', - line=dict(color='green', width=2) - )) - - fig.add_trace(go.Scatter( - x=account_df['timestamp'], - y=account_df['cash'], - mode='lines', - name='Cash', - line=dict(color='blue', width=2, dash='dash') - )) - - fig.update_layout( - xaxis_title="Date", - yaxis_title="Value ($)", - hovermode='x unified' - ) - - st.plotly_chart(fig, use_container_width=True) - - st.markdown("---") - - # Stats - st.subheader("Performance Statistics") - - initial_equity = account_df.iloc[0]['equity'] - final_equity = latest['equity'] - total_return = ((final_equity - initial_equity) / initial_equity) * 100 - - col1, col2, col3 = st.columns(3) - - with col1: - st.metric("Initial Equity", f"${initial_equity:,.2f}") - - with col2: - st.metric("Final Equity", f"${final_equity:,.2f}") - - with col3: - st.metric("Total Return", f"{total_return:.2f}%") - -else: - st.info("No account history yet") - \ No newline at end of file diff --git a/streamlit/pages/account_history.py b/streamlit/pages/account_history.py new file mode 100644 index 0000000..50b4ef5 --- /dev/null +++ b/streamlit/pages/account_history.py @@ -0,0 +1,101 @@ +import streamlit as st +import pandas as pd +import sys +sys.path.append('..') +from utils.db import get_account_history, get_active_account_ids +import plotly.graph_objects as go + +st.set_page_config(page_title="Account History", layout="wide") +st.title("Account History") + +account_ids = get_active_account_ids() + +if not account_ids: + st.info("No active accounts found") + st.stop() + +view_mode = st.radio("View", ["Single Account", "All Accounts"], horizontal=True) + +if view_mode == "Single Account": + account_id = st.selectbox("Account", account_ids) + account_df = get_account_history(account_id=account_id, limit=500) + + if account_df.empty: + st.info(f"No history for {account_id}") + st.stop() + + account_df = account_df.sort_values('timestamp') + account_df['timestamp'] = pd.to_datetime(account_df['timestamp'], format='ISO8601') + latest = account_df.iloc[-1] + + col1, col2, col3 = st.columns(3) + col1.metric("Current Equity", f"${latest['equity']:,.2f}") + col2.metric("Cash", f"${latest['cash']:,.2f}") + col3.metric("Buying Power", f"${latest['buying_power']:,.2f}") + + st.markdown("---") + st.subheader("Equity Curve") + + fig = go.Figure() + fig.add_trace(go.Scatter( + x=account_df['timestamp'], y=account_df['equity'], + mode='lines', name='Equity', line=dict(color='green', width=2) + )) + fig.add_trace(go.Scatter( + x=account_df['timestamp'], y=account_df['cash'], + mode='lines', name='Cash', line=dict(color='blue', width=2, dash='dash') + )) + fig.update_layout(xaxis_title="Date", yaxis_title="Value ($)", hovermode='x unified') + st.plotly_chart(fig, use_container_width=True) + + st.markdown("---") + st.subheader("Performance") + + initial_equity = account_df.iloc[0]['equity'] + final_equity = latest['equity'] + total_return = ((final_equity - initial_equity) / initial_equity) * 100 if initial_equity > 0 else 0 + + col1, col2, col3 = st.columns(3) + col1.metric("Initial Equity", f"${initial_equity:,.2f}") + col2.metric("Final Equity", f"${final_equity:,.2f}") + col3.metric("Total Return", f"{total_return:.2f}%") + +else: + st.subheader("All Accounts") + + rows = [] + for aid in account_ids: + df = get_account_history(account_id=aid, limit=500) + if df.empty: + continue + df = df.sort_values('timestamp') + initial = df.iloc[0]['equity'] + final = df.iloc[-1]['equity'] + ret = ((final - initial) / initial) * 100 if initial > 0 else 0 + rows.append({ + "Account": aid, + "Equity": f"${final:,.2f}", + "Cash": f"${df.iloc[-1]['cash']:,.2f}", + "Return": f"{ret:.2f}%", + }) + + if rows: + st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True) + + st.markdown("---") + st.subheader("Equity Curves") + + fig = go.Figure() + for aid in account_ids: + df = get_account_history(account_id=aid, limit=500) + if df.empty: + continue + df = df.sort_values('timestamp') + df['timestamp'] = pd.to_datetime(df['timestamp'], format='ISO8601') + fig.add_trace(go.Scatter( + x=df['timestamp'], y=df['equity'], + mode='lines', name=aid, line=dict(width=2) + )) + + fig.update_layout(xaxis_title="Date", yaxis_title="Equity ($)", hovermode='x unified') + st.plotly_chart(fig, use_container_width=True) diff --git a/streamlit/pages/accounts.py b/streamlit/pages/accounts.py new file mode 100644 index 0000000..4d18525 --- /dev/null +++ b/streamlit/pages/accounts.py @@ -0,0 +1,186 @@ +import streamlit as st +import json +import sys +sys.path.append('..') +from utils.db import ( + create_account, delete_account, get_active_account_ids, get_all_account_ids, get_strategies, + get_strategies_for_account, assign_strategy_to_account, toggle_strategy_enabled, + unassign_strategy_from_account, update_assignment_symbols, + get_account_stats, get_all_assignments, +) + +st.set_page_config(page_title="Accounts", layout="wide") +st.title("Accounts") + +if st.button("Refresh"): + st.rerun() + + +def get_all_known_symbols(): + df = get_all_assignments() + if df.empty: + return [] + symbols = set() + for _, row in df.iterrows(): + try: + syms = json.loads(row['symbols_json']) + symbols.update(syms) + except (json.JSONDecodeError, TypeError): + pass + return sorted(symbols) + + +def symbol_picker(label, current_symbols, key_prefix): + known = get_all_known_symbols() + all_options = sorted(set(known + current_symbols)) + + selected = st.multiselect( + label, + options=all_options, + default=current_symbols, + key=f"{key_prefix}_multi", + placeholder="Select or type symbols...", + ) + + new_input = st.text_input( + "Add new symbols (comma-separated)", + key=f"{key_prefix}_new", + placeholder="e.g. NVDA, AMD", + ) + + if new_input.strip(): + for s in new_input.split(","): + s = s.strip().upper() + if s and s not in selected: + selected.append(s) + + return selected + + +# ============ Account Overview ============ + +st.subheader("Active Accounts") + +account_ids = get_all_account_ids() +active_ids = get_active_account_ids() + +if account_ids: + for aid in account_ids: + is_active = aid in active_ids + status = "Active" if is_active else "Inactive" + + with st.expander(f"{aid} — {status}"): + stats = get_account_stats(aid) + if stats: + col1, col2, col3 = st.columns(3) + col1.metric("Equity", f"${stats['equity']:,.2f}") + col2.metric("Trades", stats['total_trades']) + col3.metric("Signals", stats['total_signals']) + + strategies_df = get_strategies_for_account(aid) + if not strategies_df.empty: + for _, row in strategies_df.iterrows(): + symbols = json.loads(row['symbols_json']) if row['symbols_json'] else [] + is_enabled = bool(row['enabled']) + + col_name, col_toggle = st.columns([4, 1]) + with col_name: + st.markdown(f"**{row['name']}** ({row['strategy_type']})") + with col_toggle: + new_enabled = st.toggle( + "Enabled", + value=is_enabled, + key=f"toggle_{aid}_{row['id']}", + ) + if new_enabled != is_enabled: + toggle_strategy_enabled(aid, row['id'], new_enabled) + st.rerun() + + updated_symbols = symbol_picker( + "Symbols", + symbols, + key_prefix=f"edit_{aid}_{row['id']}", + ) + + col1, col2 = st.columns([1, 1]) + with col1: + if updated_symbols != symbols: + if st.button("Save Changes", key=f"save_{aid}_{row['id']}"): + update_assignment_symbols(aid, row['id'], updated_symbols) + st.success(f"Updated symbols to {updated_symbols}") + st.rerun() + with col2: + if st.button("Remove Strategy", key=f"rm_{aid}_{row['id']}", type="secondary"): + unassign_strategy_from_account(aid, row['id']) + st.rerun() + + st.divider() + + if st.button("Delete Account", key=f"del_acct_{aid}", type="primary"): + delete_account(aid) + st.success(f"Deleted account {aid}") + st.rerun() + else: + st.info("No strategies assigned") +else: + st.info("No accounts yet. Assign a strategy below to create one.") + +st.markdown("---") + +# ============ Create Account ============ + +st.subheader("Create Account") + +col1, col2 = st.columns(2) +with col1: + new_account_id = st.text_input("Account ID", placeholder="e.g. paper_aggressive", key="new_acct_id") +with col2: + starting_cash = st.number_input("Starting Cash", min_value=1000.0, value=100000.0, step=10000.0, format="%.2f", key="new_acct_cash") + +if st.button("Create Account"): + if not new_account_id.strip(): + st.error("Account ID is required") + elif new_account_id.strip() in get_all_account_ids(): + st.error("Account already exists") + else: + create_account(new_account_id.strip(), starting_cash) + st.success(f"Created account {new_account_id} with ${starting_cash:,.2f}") + st.rerun() + +st.markdown("---") + +# ============ Assign Strategy to Account ============ + +st.subheader("Assign Strategy to Account") + +all_strategies = get_strategies() + +if all_strategies.empty: + st.info("No strategies created yet. Go to the Strategies page to create one.") + st.stop() + +existing_accounts = get_all_account_ids() + +if not existing_accounts: + st.info("Create an account first.") + st.stop() + +col1, col2 = st.columns(2) + +with col1: + account_id = st.selectbox("Account", existing_accounts, key="assign_acct") + +with col2: + strategy_options = {row['name']: row['id'] for _, row in all_strategies.iterrows()} + selected_name = st.selectbox("Strategy", list(strategy_options.keys())) + +symbols = symbol_picker("Symbols", [], key_prefix="assign_new") + +if st.button("Assign"): + if not symbols: + st.error("At least one symbol is required") + else: + strategy_id = strategy_options[selected_name] + assign_strategy_to_account(account_id, strategy_id, symbols) + st.success(f"Assigned {selected_name} to {account_id} with {symbols}") + st.rerun() \ No newline at end of file diff --git a/streamlit/pages/overview.py b/streamlit/pages/overview.py index 8ed4886..c63bfdf 100644 --- a/streamlit/pages/overview.py +++ b/streamlit/pages/overview.py @@ -2,82 +2,53 @@ import pandas as pd import sys sys.path.append('..') -from utils.db import ( - get_account_stats, - get_engine_trades, - get_signals, - get_positions, - get_account_history, -) -import plotly.express as px +from utils.db import get_account_stats, get_trades, get_signals, get_account_history, get_active_account_ids import plotly.graph_objects as go st.set_page_config(page_title="Cubert Overview", layout="wide") -st.title("📊 Overview") +st.title("Overview") -# Auto-refresh button -if st.button("🔄 Refresh"): +if st.button("Refresh"): st.rerun() -# Get stats -stats = get_account_stats() +account_ids = get_active_account_ids() -# Metrics row 1 -col1, col2, col3, col4 = st.columns(4) +if not account_ids: + st.info("No active accounts. Create an account and assign strategies to get started.") + st.stop() -with col1: - st.metric("💰 Equity", f"${stats['equity']:,.2f}") +account_id = st.selectbox("Account", account_ids) -with col2: - st.metric("💵 Cash", f"${stats['cash']:,.2f}") +stats = get_account_stats(account_id) -with col3: - pnl = stats['unrealized_pnl'] - st.metric("📈 Unrealized P&L", f"${pnl:,.2f}", delta=f"{pnl:,.2f}") +if not stats: + st.info(f"No data for {account_id} yet.") + st.stop() -with col4: - st.metric("💼 Positions", stats['positions_count']) - -# Metrics row 2 col1, col2, col3, col4 = st.columns(4) - -with col1: - st.metric("📊 Total Trades", stats['total_trades']) - -with col2: - st.metric("📡 Total Signals", stats['total_signals']) - -with col3: - st.metric("💳 Buying Power", f"${stats['buying_power']:,.2f}") - -with col4: - # Calculate win rate if we have trades - pass # Placeholder for future metrics +col1.metric("Equity", f"${stats['equity']:,.2f}") +col2.metric("Cash", f"${stats['cash']:,.2f}") +col3.metric("Trades", stats['total_trades']) +col4.metric("Signals", stats['total_signals']) st.markdown("---") -# Equity curve -st.subheader("📈 Equity Curve") -history = get_account_history(500) +st.subheader("Equity Curve") +history = get_account_history(account_id, 500) if not history.empty: - history['timestamp'] = pd.to_datetime(history['timestamp']) + history['timestamp'] = pd.to_datetime(history['timestamp'], format='ISO8601') history = history.sort_values('timestamp') - + fig = go.Figure() fig.add_trace(go.Scatter( - x=history['timestamp'], - y=history['equity'], - mode='lines', - name='Equity', + x=history['timestamp'], y=history['equity'], + mode='lines', name='Equity', line=dict(color='#00d4aa', width=2), - fill='tozeroy', - fillcolor='rgba(0, 212, 170, 0.1)' + fill='tozeroy', fillcolor='rgba(0, 212, 170, 0.1)' )) fig.update_layout( - xaxis_title="Time", - yaxis_title="Equity ($)", - hovermode='x unified', - height=300, + xaxis_title="Time", yaxis_title="Equity ($)", + hovermode='x unified', height=300, margin=dict(l=0, r=0, t=10, b=0), ) st.plotly_chart(fig, use_container_width=True) @@ -86,48 +57,18 @@ st.markdown("---") -# Current positions -st.subheader("💼 Current Positions") -positions = get_positions() -if not positions.empty: - positions['market_value'] = positions['quantity'] * positions['current_price'] - positions['pnl'] = (positions['current_price'] - positions['avg_entry_price']) * positions['quantity'] - positions['pnl_pct'] = ((positions['current_price'] / positions['avg_entry_price']) - 1) * 100 - - st.dataframe( - positions[['symbol', 'quantity', 'avg_entry_price', 'current_price', 'market_value', 'pnl', 'pnl_pct']], - hide_index=True, - use_container_width=True, - column_config={ - 'symbol': 'Symbol', - 'quantity': st.column_config.NumberColumn('Qty', format='%.2f'), - 'avg_entry_price': st.column_config.NumberColumn('Avg Entry', format='$%.2f'), - 'current_price': st.column_config.NumberColumn('Current', format='$%.2f'), - 'market_value': st.column_config.NumberColumn('Value', format='$%.2f'), - 'pnl': st.column_config.NumberColumn('P&L', format='$%.2f'), - 'pnl_pct': st.column_config.NumberColumn('P&L %', format='%.2f%%'), - } - ) -else: - st.info("No open positions") - -st.markdown("---") - -# Recent activity col1, col2 = st.columns(2) with col1: - st.subheader("🔄 Recent Trades") - trades_df = get_engine_trades(limit=10) + st.subheader("Recent Trades") + trades_df = get_trades(account_id=account_id, limit=10) if not trades_df.empty: trades_df['value'] = trades_df['quantity'] * trades_df['price'] st.dataframe( trades_df[['symbol', 'side', 'quantity', 'price', 'value', 'timestamp']], - hide_index=True, - use_container_width=True, + hide_index=True, use_container_width=True, column_config={ - 'symbol': 'Symbol', - 'side': 'Side', + 'symbol': 'Symbol', 'side': 'Side', 'quantity': st.column_config.NumberColumn('Qty', format='%.2f'), 'price': st.column_config.NumberColumn('Price', format='$%.2f'), 'value': st.column_config.NumberColumn('Value', format='$%.2f'), @@ -138,19 +79,16 @@ st.info("No trades yet") with col2: - st.subheader("📡 Recent Signals") - signals_df = get_signals(limit=10) + st.subheader("Recent Signals") + signals_df = get_signals(account_id=account_id, limit=10) if not signals_df.empty: st.dataframe( signals_df[['symbol', 'signal_type', 'strength', 'strategy', 'timestamp']], - hide_index=True, - use_container_width=True, + hide_index=True, use_container_width=True, column_config={ - 'symbol': 'Symbol', - 'signal_type': 'Signal', + 'symbol': 'Symbol', 'signal_type': 'Signal', 'strength': st.column_config.NumberColumn('Strength', format='%.2f'), - 'strategy': 'Strategy', - 'timestamp': 'Time', + 'strategy': 'Strategy', 'timestamp': 'Time', } ) else: diff --git a/streamlit/pages/positions.py b/streamlit/pages/positions.py index 086d5bb..2c17f86 100644 --- a/streamlit/pages/positions.py +++ b/streamlit/pages/positions.py @@ -1,12 +1,19 @@ import streamlit as st import sys sys.path.append('..') -from utils.db import get_positions +from utils.db import get_positions, get_active_account_ids import plotly.graph_objects as go -st.title("📈 Current Positions") +st.title("Positions") -positions_df = get_positions() +account_ids = get_active_account_ids() + +if not account_ids: + st.info("No active accounts") + st.stop() + +account_id = st.selectbox("Account", account_ids) +positions_df = get_positions(account_id) if not positions_df.empty: # Calculate metrics diff --git a/streamlit/pages/signals.py b/streamlit/pages/signals.py index b8868cb..f57ff24 100644 --- a/streamlit/pages/signals.py +++ b/streamlit/pages/signals.py @@ -1,66 +1,49 @@ import streamlit as st +import pandas as pd import sys sys.path.append('..') -from utils.db import get_signals +from utils.db import get_signals, get_active_account_ids import plotly.express as px -import pandas as pd -st.title("📡 Signals") +st.set_page_config(page_title="Signals", layout="wide") +st.title("Signals") -# Filters -col1, col2, col3 = st.columns(3) +account_ids = get_active_account_ids() +col1, col2, col3 = st.columns(3) with col1: - symbol_filter = st.selectbox("Symbol", ["All", "AAPL", "MSFT", "TSLA"]) - + view = st.selectbox("Account", ["All"] + account_ids) with col2: signal_type_filter = st.selectbox("Signal Type", ["All", "buy", "sell"]) - with col3: - limit = st.number_input("Records to show", min_value=10, max_value=1000, value=100) + limit = st.number_input("Limit", min_value=10, max_value=1000, value=100) -# Get data -signals_df = get_signals(limit=limit) +account_filter = None if view == "All" else view +signals_df = get_signals(account_id=account_filter, limit=limit) -if not signals_df.empty: - # Apply filters - if symbol_filter != "All": - signals_df = signals_df[signals_df['symbol'] == symbol_filter] - - if signal_type_filter != "All": - signals_df = signals_df[signals_df['signal_type'] == signal_type_filter] - - # Stats - col1, col2, col3 = st.columns(3) - - with col1: - st.metric("Total Signals", len(signals_df)) - - with col2: - buy_count = len(signals_df[signals_df['signal_type'] == 'buy']) - st.metric("Buy Signals", buy_count) - - with col3: - sell_count = len(signals_df[signals_df['signal_type'] == 'sell']) - st.metric("Sell Signals", sell_count) - - st.markdown("---") - - # Signal distribution - st.subheader("Signal Distribution by Symbol") - signal_counts = signals_df.groupby(['symbol', 'signal_type']).size().reset_index(name='count') - fig = px.bar(signal_counts, x='symbol', y='count', color='signal_type', barmode='group') - st.plotly_chart(fig, use_container_width=True) - - st.markdown("---") - - # Data table - st.subheader("Signal History") - st.dataframe( - signals_df[['timestamp', 'symbol', 'signal_type', 'strength', 'strategy']], - hide_index=True, - use_container_width=True - ) -else: +if signals_df.empty: st.info("No signals recorded yet") - \ No newline at end of file + st.stop() + +if signal_type_filter != "All": + signals_df = signals_df[signals_df['signal_type'] == signal_type_filter] + +col1, col2, col3 = st.columns(3) +col1.metric("Total Signals", len(signals_df)) +col2.metric("Buy Signals", len(signals_df[signals_df['signal_type'] == 'buy'])) +col3.metric("Sell Signals", len(signals_df[signals_df['signal_type'] == 'sell'])) + +st.markdown("---") + +st.subheader("Signal Distribution by Symbol") +signal_counts = signals_df.groupby(['symbol', 'signal_type']).size().reset_index(name='count') +fig = px.bar(signal_counts, x='symbol', y='count', color='signal_type', barmode='group') +st.plotly_chart(fig, use_container_width=True) + +st.markdown("---") + +st.subheader("Signal History") +display_cols = ['timestamp', 'symbol', 'signal_type', 'strength', 'strategy'] +if view == "All": + display_cols.insert(1, 'account_id') +st.dataframe(signals_df[display_cols], hide_index=True, use_container_width=True) diff --git a/streamlit/pages/strategies.py b/streamlit/pages/strategies.py new file mode 100644 index 0000000..64f322d --- /dev/null +++ b/streamlit/pages/strategies.py @@ -0,0 +1,129 @@ +import streamlit as st +import json +import sys +sys.path.append('..') +from utils.db import get_strategies, insert_strategy, update_strategy, delete_strategy +from strategies import STRATEGY_REGISTRY + +st.set_page_config(page_title="Strategies", layout="wide") +st.title("Strategies") + +if st.button("Refresh"): + st.rerun() + +# ============ Existing Strategies ============ + +st.subheader("Existing Strategies") + +strategies_df = get_strategies() + +if not strategies_df.empty: + for _, row in strategies_df.iterrows(): + params = json.loads(row['params_json']) if row['params_json'] else {} + param_type = params.get("type", row['strategy_type']) + definition = STRATEGY_REGISTRY.get(param_type) + + with st.expander(f"{row['name']} — {row['strategy_type']}"): + st.caption(f"ID: {row['id']} | Created: {row['created_at']}") + + new_name = st.text_input("Name", value=row['name'], key=f"name_{row['id']}") + + if definition: + st.write(f"**{definition.name()} Parameters**") + descs = definition.param_descriptions() + + edited_params = {} + col1, col2 = st.columns(2) + + param_keys = [k for k in params if k != "type"] + half = (len(param_keys) + 1) // 2 + + for i, key in enumerate(param_keys): + col = col1 if i < half else col2 + value = params[key] + label = key.replace('_', ' ').title() + + with col: + if isinstance(value, float): + edited_params[key] = st.number_input( + label, value=value, step=0.005, format="%.3f", + key=f"param_{row['id']}_{key}" + ) + elif isinstance(value, int): + edited_params[key] = st.number_input( + label, value=value, min_value=1, + key=f"param_{row['id']}_{key}" + ) + elif key.endswith("_lookback") or key.endswith("_timeframe"): + options = ["1Min", "5Min", "15Min", "30Min", "1Hour", "1Day"] + idx = options.index(value) if value in options else 0 + edited_params[key] = st.selectbox( + label, options, index=idx, + key=f"param_{row['id']}_{key}" + ) + else: + edited_params[key] = st.text_input( + label, value=str(value), + key=f"param_{row['id']}_{key}" + ) + + desc = descs.get(key, "") + if desc: + st.caption(desc) + + # Check for changes + name_changed = new_name != row['name'] + params_changed = edited_params != {k: v for k, v in params.items() if k != "type"} + + col1, col2 = st.columns([1, 1]) + with col1: + if name_changed or params_changed: + if st.button("Save Changes", key=f"save_{row['id']}"): + new_params = {"type": param_type, **edited_params} + update_strategy( + row['id'], + new_name, + row['strategy_type'], + new_params, + ) + st.success("Saved") + st.rerun() + else: + st.button("Save Changes", key=f"save_{row['id']}", disabled=True) + + with col2: + if st.button("Delete", key=f"del_{row['id']}", type="primary"): + delete_strategy(row['id']) + st.success(f"Deleted {row['name']}") + st.rerun() + else: + st.json(params) + if st.button("Delete", key=f"del_{row['id']}", type="primary"): + delete_strategy(row['id']) + st.rerun() +else: + st.info("No strategies created yet.") + +st.markdown("---") + +# ============ Create Strategy ============ + +st.subheader("Create Strategy") + +strategy_types = list(STRATEGY_REGISTRY.keys()) +selected_type = st.selectbox("Strategy Type", strategy_types) +definition = STRATEGY_REGISTRY[selected_type] + +name = st.text_input("Strategy Name", placeholder=f"e.g. {definition.name().lower()}_v1") + +st.write(f"**{definition.name()} Parameters**") +params = definition.render_params(st) + +if st.button("Create Strategy"): + if not name.strip(): + st.error("Strategy name is required") + else: + params_json = definition.to_params_json(params) + strategy_id = insert_strategy(name.strip(), selected_type, params_json) + st.success(f"Created strategy '{name}' (ID: {strategy_id})") + st.rerun() diff --git a/streamlit/pages/trades.py b/streamlit/pages/trades.py index 69462f3..aba20cb 100644 --- a/streamlit/pages/trades.py +++ b/streamlit/pages/trades.py @@ -1,54 +1,53 @@ import streamlit as st +import pandas as pd import sys sys.path.append('..') -from utils.db import get_engine_trades +from utils.db import get_trades, get_active_account_ids import plotly.express as px -import pandas as pd -st.title("💼 Trades") - -# Get data -trades_df = get_engine_trades(limit=200) - -if not trades_df.empty: - # Calculate P&L - trades_df['value'] = trades_df['quantity'] * trades_df['price'] - - # Stats - col1, col2, col3, col4 = st.columns(4) - - with col1: - st.metric("Total Trades", len(trades_df)) - - with col2: - buy_value = trades_df[trades_df['side'] == 'buy']['value'].sum() - st.metric("Total Bought", f"${buy_value:,.2f}") - - with col3: - sell_value = trades_df[trades_df['side'] == 'sell']['value'].sum() - st.metric("Total Sold", f"${sell_value:,.2f}") - - with col4: - pnl = sell_value - buy_value - st.metric("Net P&L", f"${pnl:,.2f}", delta=f"{pnl:,.2f}") - - st.markdown("---") - - # Trade volume by symbol - st.subheader("Trade Volume by Symbol") - volume_by_symbol = trades_df.groupby('symbol')['value'].sum().reset_index() - fig = px.pie(volume_by_symbol, values='value', names='symbol', title='Trade Volume Distribution') - st.plotly_chart(fig, use_container_width=True) - - st.markdown("---") - - # Trade history - st.subheader("Trade History") - st.dataframe( - trades_df[['timestamp', 'symbol', 'side', 'quantity', 'price', 'strategy']], - hide_index=True, - use_container_width=True - ) -else: +st.set_page_config(page_title="Trades", layout="wide") +st.title("Trades") + +account_ids = get_active_account_ids() + +col1, col2 = st.columns(2) +with col1: + view = st.selectbox("Account", ["All"] + account_ids) +with col2: + limit = st.number_input("Limit", min_value=10, max_value=1000, value=200) + +account_filter = None if view == "All" else view +trades_df = get_trades(account_id=account_filter, limit=limit) + +if trades_df.empty: st.info("No trades yet") - \ No newline at end of file + st.stop() + +trades_df['value'] = trades_df['quantity'] * trades_df['price'] + +col1, col2, col3, col4 = st.columns(4) +col1.metric("Total Trades", len(trades_df)) + +buy_value = trades_df[trades_df['side'] == 'buy']['value'].sum() +col2.metric("Total Bought", f"${buy_value:,.2f}") + +sell_value = trades_df[trades_df['side'] == 'sell']['value'].sum() +col3.metric("Total Sold", f"${sell_value:,.2f}") + +pnl = sell_value - buy_value +col4.metric("Net P&L", f"${pnl:,.2f}") + +st.markdown("---") + +st.subheader("Trade Volume by Symbol") +volume_by_symbol = trades_df.groupby('symbol')['value'].sum().reset_index() +fig = px.pie(volume_by_symbol, values='value', names='symbol') +st.plotly_chart(fig, use_container_width=True) + +st.markdown("---") + +st.subheader("Trade History") +display_cols = ['timestamp', 'symbol', 'side', 'quantity', 'price', 'strategy'] +if view == "All": + display_cols.insert(1, 'account_id') +st.dataframe(trades_df[display_cols], hide_index=True, use_container_width=True) diff --git a/streamlit/strategies/__init__.py b/streamlit/strategies/__init__.py new file mode 100644 index 0000000..68801a4 --- /dev/null +++ b/streamlit/strategies/__init__.py @@ -0,0 +1,15 @@ +from .momentum import MomentumDefinition +from .mean_reversion import MeanReversionDefinition + +_ALL_STRATEGIES = [ + MomentumDefinition(), + MeanReversionDefinition(), +] + +STRATEGY_REGISTRY = {s.strategy_type(): s for s in _ALL_STRATEGIES} + +def get_strategy_types(): + return list(STRATEGY_REGISTRY.keys()) + +def get_strategy_definition(strategy_type: str): + return STRATEGY_REGISTRY.get(strategy_type) diff --git a/streamlit/strategies/base.py b/streamlit/strategies/base.py new file mode 100644 index 0000000..a365080 --- /dev/null +++ b/streamlit/strategies/base.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod + + +class StrategyDefinition(ABC): + @abstractmethod + def name(self) -> str: + pass + + @abstractmethod + def strategy_type(self) -> str: + pass + + @abstractmethod + def default_params(self) -> dict: + pass + + @abstractmethod + def param_descriptions(self) -> dict: + pass + + @abstractmethod + def render_params(self, st) -> dict: + pass + + def to_params_json(self, params: dict) -> dict: + return {"type": self.strategy_type(), **params} diff --git a/streamlit/strategies/mean_reversion.py b/streamlit/strategies/mean_reversion.py new file mode 100644 index 0000000..bf93ca4 --- /dev/null +++ b/streamlit/strategies/mean_reversion.py @@ -0,0 +1,52 @@ +from .base import StrategyDefinition + + +class MeanReversionDefinition(StrategyDefinition): + def name(self) -> str: + return "Mean Reversion" + + def strategy_type(self) -> str: + return "MeanReversion" + + def default_params(self) -> dict: + return { + "window": 20, + "std_devs": 2.0, + "startup_lookback": "1Hour", + "startup_bar_limit": 50, + } + + def param_descriptions(self) -> dict: + return { + "window": "Rolling window size for computing the moving average.", + "std_devs": "Number of standard deviations from the mean to trigger a signal.", + "startup_lookback": "Bar timeframe used during historical warmup before live trading begins.", + "startup_bar_limit": "Number of historical bars to fetch during warmup.", + } + + def render_params(self, st) -> dict: + descs = self.param_descriptions() + + col1, col2 = st.columns(2) + with col1: + window = st.number_input("Window Size", min_value=5, value=20) + st.caption(descs["window"]) + + startup_lookback = st.selectbox( + "Startup Timeframe", ["1Min", "5Min", "15Min", "30Min", "1Hour", "1Day"], index=4 + ) + st.caption(descs["startup_lookback"]) + + with col2: + std_devs = st.number_input("Std Deviations", min_value=0.5, value=2.0, step=0.1, format="%.1f") + st.caption(descs["std_devs"]) + + startup_bar_limit = st.number_input("Startup Bar Limit", min_value=10, value=50) + st.caption(descs["startup_bar_limit"]) + + return { + "window": window, + "std_devs": std_devs, + "startup_lookback": startup_lookback, + "startup_bar_limit": startup_bar_limit, + } diff --git a/streamlit/strategies/momentum.py b/streamlit/strategies/momentum.py new file mode 100644 index 0000000..e8b3b2b --- /dev/null +++ b/streamlit/strategies/momentum.py @@ -0,0 +1,52 @@ +from .base import StrategyDefinition + + +class MomentumDefinition(StrategyDefinition): + def name(self) -> str: + return "Momentum" + + def strategy_type(self) -> str: + return "Momentum" + + def default_params(self) -> dict: + return { + "lookback_period": 20, + "threshold": 0.02, + "startup_lookback": "1Hour", + "startup_bar_limit": 50, + } + + def param_descriptions(self) -> dict: + return { + "lookback_period": "Number of bars to compare current price against for momentum calculation.", + "threshold": "Minimum momentum percentage to trigger a signal. Higher = fewer trades.", + "startup_lookback": "Bar timeframe used during historical warmup before live trading begins.", + "startup_bar_limit": "Number of historical bars to fetch during warmup.", + } + + def render_params(self, st) -> dict: + descs = self.param_descriptions() + + col1, col2 = st.columns(2) + with col1: + lookback = st.number_input("Lookback Period", min_value=1, value=20) + st.caption(descs["lookback_period"]) + + startup_lookback = st.selectbox( + "Startup Timeframe", ["1Min", "5Min", "15Min", "30Min", "1Hour", "1Day"], index=4 + ) + st.caption(descs["startup_lookback"]) + + with col2: + threshold = st.number_input("Threshold", min_value=0.001, value=0.02, step=0.005, format="%.3f") + st.caption(descs["threshold"]) + + startup_bar_limit = st.number_input("Startup Bar Limit", min_value=10, value=50) + st.caption(descs["startup_bar_limit"]) + + return { + "lookback_period": lookback, + "threshold": threshold, + "startup_lookback": startup_lookback, + "startup_bar_limit": startup_bar_limit, + } diff --git a/streamlit/utils/db.py b/streamlit/utils/db.py index 86809f5..0a75214 100644 --- a/streamlit/utils/db.py +++ b/streamlit/utils/db.py @@ -1,182 +1,319 @@ import sqlite3 +import json +import uuid +from datetime import datetime import pandas as pd -BROKER_DB = "broker.db" ENGINE_DB = "cubert.db" -def get_broker_connection(): - return sqlite3.connect(BROKER_DB) - - def get_engine_connection(): return sqlite3.connect(ENGINE_DB) -# ============ Broker DB (source of truth) ============ +# ============ Accounts ============ -def get_account(): - """Get current account state from broker""" - conn = get_broker_connection() - query = "SELECT cash FROM account LIMIT 1" - df = pd.read_sql_query(query, conn) - - positions = get_positions() - positions_value = (positions['quantity'] * positions['current_price']).sum() if not positions.empty else 0 - +def get_active_account_ids(): + conn = get_engine_connection() + df = pd.read_sql_query( + "SELECT DISTINCT account_id FROM account_strategies WHERE enabled = TRUE", + conn + ) + conn.close() + return df['account_id'].tolist() + + +def get_all_account_ids(): + conn = get_engine_connection() + df = pd.read_sql_query("SELECT id FROM account", conn) + conn.close() + return df['id'].tolist() + + +def create_account(account_id, starting_cash=100000.0): + conn = get_engine_connection() + now = datetime.utcnow().isoformat() + + conn.execute( + """INSERT OR IGNORE INTO account (id, cash, created_at, updated_at) + VALUES (?, ?, ?, ?)""", + (account_id, starting_cash, now, now) + ) + + conn.execute( + """INSERT INTO account_snapshots (id, account_id, equity, cash, buying_power, timestamp) + VALUES (?, ?, ?, ?, ?, ?)""", + (str(uuid.uuid4()), account_id, starting_cash, starting_cash, starting_cash, now) + ) + + conn.commit() + conn.close() + +def delete_account(account_id): + conn = get_engine_connection() + conn.execute("DELETE FROM account_strategies WHERE account_id = ?", (account_id,)) + conn.execute("DELETE FROM account WHERE id = ?", (account_id,)) + conn.execute("DELETE FROM signals WHERE account_id = ?", (account_id,)) + conn.execute("DELETE FROM trades WHERE account_id = ?", (account_id,)) + conn.execute("DELETE FROM account_snapshots WHERE account_id = ?", (account_id,)) + conn.execute("DELETE FROM positions WHERE account_id = ?", (account_id,)) + conn.commit() conn.close() - - cash = df['cash'].iloc[0] if not df.empty else 0 - return { - 'equity': cash + positions_value, - 'cash': cash, - 'buying_power': cash, - } +# ============ Account History ============ -def get_positions(): - """Get current positions from broker""" - conn = get_broker_connection() - query = "SELECT * FROM positions WHERE quantity != 0" - df = pd.read_sql_query(query, conn) +def get_account_history(account_id, limit=1000): + conn = get_engine_connection() + df = pd.read_sql_query( + "SELECT * FROM account_snapshots WHERE account_id = ? ORDER BY timestamp ASC LIMIT ?", + conn, + params=(account_id, limit) + ) conn.close() return df -def get_broker_trades(limit=100): - """Get executed trades from broker (source of truth)""" - conn = get_broker_connection() - query = """ - SELECT * FROM trades - ORDER BY timestamp DESC - LIMIT ? - """ - df = pd.read_sql_query(query, conn, params=(limit,)) +def get_latest_snapshot(account_id): + conn = get_engine_connection() + df = pd.read_sql_query( + "SELECT * FROM account_snapshots WHERE account_id = ? ORDER BY timestamp DESC LIMIT 1", + conn, + params=(account_id,) + ) + conn.close() + if df.empty: + return None + return df.iloc[0].to_dict() + + +# ============ Trades ============ + +def get_trades(account_id=None, limit=100): + conn = get_engine_connection() + if account_id: + df = pd.read_sql_query( + "SELECT * FROM trades WHERE account_id = ? ORDER BY timestamp DESC LIMIT ?", + conn, + params=(account_id, limit) + ) + else: + df = pd.read_sql_query( + "SELECT * FROM trades ORDER BY timestamp DESC LIMIT ?", + conn, + params=(limit,) + ) conn.close() return df -def get_account_history(limit=1000): - """Get account equity history from broker""" - conn = get_broker_connection() - query = """ - SELECT * FROM account_snapshots - ORDER BY timestamp ASC - LIMIT ? - """ - df = pd.read_sql_query(query, conn, params=(limit,)) +# ============ Signals ============ + +def get_signals(account_id=None, limit=100): + conn = get_engine_connection() + if account_id: + df = pd.read_sql_query( + "SELECT * FROM signals WHERE account_id = ? ORDER BY timestamp DESC LIMIT ?", + conn, + params=(account_id, limit) + ) + else: + df = pd.read_sql_query( + "SELECT * FROM signals ORDER BY timestamp DESC LIMIT ?", + conn, + params=(limit,) + ) conn.close() return df -# ============ Engine DB (logging/analytics) ============ +# ============ Strategies ============ -def get_signals(limit=100): - """Get signals from engine""" +def get_strategies(): conn = get_engine_connection() - query = """ - SELECT * FROM signals - ORDER BY timestamp DESC - LIMIT ? - """ - df = pd.read_sql_query(query, conn, params=(limit,)) + df = pd.read_sql_query("SELECT * FROM strategies", conn) conn.close() return df -def get_engine_trades(limit=100): - """Get trades with strategy info from engine""" +def get_strategies_for_account(account_id): + conn = get_engine_connection() + df = pd.read_sql_query( + """SELECT s.*, a.symbols_json, a.enabled, a.assigned_at + FROM strategies s + JOIN account_strategies a ON s.id = a.strategy_id + WHERE a.account_id = ?""", + conn, + params=(account_id,) + ) + conn.close() + return df + + +def insert_strategy(name, strategy_type, params_json): + conn = get_engine_connection() + strategy_id = str(uuid.uuid4()) + now = datetime.utcnow().isoformat() + conn.execute( + "INSERT INTO strategies (id, name, strategy_type, params_json, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)", + (strategy_id, name, strategy_type, json.dumps(params_json), now, now) + ) + conn.commit() + conn.close() + return strategy_id + + +def update_strategy(strategy_id, name, strategy_type, params_json): + conn = get_engine_connection() + now = datetime.utcnow().isoformat() + conn.execute( + "UPDATE strategies SET name = ?, strategy_type = ?, params_json = ?, updated_at = ? WHERE id = ?", + (name, strategy_type, json.dumps(params_json), now, strategy_id) + ) + conn.commit() + conn.close() + + +def delete_strategy(strategy_id): + conn = get_engine_connection() + conn.execute("DELETE FROM account_strategies WHERE strategy_id = ?", (strategy_id,)) + conn.execute("DELETE FROM strategies WHERE id = ?", (strategy_id,)) + conn.commit() + conn.close() + +# ============ Positions ============ + +def get_positions(account_id): conn = get_engine_connection() - query = """ - SELECT * FROM trades - ORDER BY timestamp DESC - LIMIT ? - """ - df = pd.read_sql_query(query, conn, params=(limit,)) + df = pd.read_sql_query( + "SELECT * FROM positions WHERE account_id = ? AND quantity != 0", + conn, + params=(account_id,) + ) conn.close() return df +# ============ Account Strategies ============ + +def assign_strategy_to_account(account_id, strategy_id, symbols): + conn = get_engine_connection() + conn.execute( + "INSERT OR REPLACE INTO account_strategies (account_id, strategy_id, symbols_json, enabled) VALUES (?, ?, ?, TRUE)", + (account_id, strategy_id, json.dumps(symbols)) + ) + conn.commit() + conn.close() + + +def unassign_strategy_from_account(account_id, strategy_id): + conn = get_engine_connection() + conn.execute( + "DELETE FROM account_strategies WHERE account_id = ? AND strategy_id = ?", + (account_id, strategy_id) + ) + conn.commit() + conn.close() + +def update_assignment_symbols(account_id, strategy_id, symbols): + conn = get_engine_connection() + conn.execute( + "UPDATE account_strategies SET symbols_json = ? WHERE account_id = ? AND strategy_id = ?", + (json.dumps(symbols), account_id, strategy_id) + ) + conn.commit() + conn.close() -def get_engine_snapshots(limit=100): - """Get account snapshots from engine (for comparison)""" + +def toggle_strategy_enabled(account_id, strategy_id, enabled): conn = get_engine_connection() - query = """ - SELECT * FROM account_snapshots - ORDER BY timestamp DESC - LIMIT ? - """ - df = pd.read_sql_query(query, conn, params=(limit,)) + conn.execute( + "UPDATE account_strategies SET enabled = ? WHERE account_id = ? AND strategy_id = ?", + (enabled, account_id, strategy_id) + ) + conn.commit() + conn.close() + + +def get_all_assignments(): + conn = get_engine_connection() + df = pd.read_sql_query( + """SELECT a.account_id, a.strategy_id, a.symbols_json, a.enabled, a.assigned_at, s.name, s.strategy_type + FROM account_strategies a + JOIN strategies s ON s.id = a.strategy_id""", + conn + ) conn.close() return df -# ============ Combined Stats ============ +# ============ Stats ============ + +def get_account_stats(account_id): + latest = get_latest_snapshot(account_id) + if not latest: + return None -def get_account_stats(): - """Get combined stats from both databases""" - account = get_account() - positions = get_positions() - - # From engine - engine_conn = get_engine_connection() - cursor = engine_conn.cursor() - - cursor.execute("SELECT COUNT(*) FROM trades") + conn = get_engine_connection() + cursor = conn.cursor() + + cursor.execute("SELECT COUNT(*) FROM trades WHERE account_id = ?", (account_id,)) total_trades = cursor.fetchone()[0] - - cursor.execute("SELECT COUNT(*) FROM signals") + + cursor.execute("SELECT COUNT(*) FROM signals WHERE account_id = ?", (account_id,)) total_signals = cursor.fetchone()[0] - - engine_conn.close() - - # Calculate P&L - if not positions.empty: - positions['pnl'] = (positions['current_price'] - positions['avg_entry_price']) * positions['quantity'] - total_pnl = positions['pnl'].sum() - else: - total_pnl = 0 - + + conn.close() + return { - 'equity': account['equity'], - 'cash': account['cash'], - 'buying_power': account['buying_power'], + 'account_id': account_id, + 'equity': latest['equity'], + 'cash': latest['cash'], + 'buying_power': latest['buying_power'], 'total_trades': total_trades, 'total_signals': total_signals, - 'positions_count': len(positions), - 'unrealized_pnl': total_pnl, } -def get_strategy_performance(): - """Get performance by strategy from engine""" +def get_strategy_performance(account_id=None): conn = get_engine_connection() - query = """ - SELECT - strategy, - COUNT(*) as trade_count, - SUM(CASE WHEN side = 'buy' THEN quantity * price ELSE 0 END) as total_bought, - SUM(CASE WHEN side = 'sell' THEN quantity * price ELSE 0 END) as total_sold - FROM trades - WHERE strategy IS NOT NULL - GROUP BY strategy - """ - df = pd.read_sql_query(query, conn) + if account_id: + df = pd.read_sql_query( + """SELECT strategy, COUNT(*) as trade_count, + SUM(CASE WHEN side = 'buy' THEN quantity * price ELSE 0 END) as total_bought, + SUM(CASE WHEN side = 'sell' THEN quantity * price ELSE 0 END) as total_sold + FROM trades WHERE strategy IS NOT NULL AND account_id = ? + GROUP BY strategy""", + conn, + params=(account_id,) + ) + else: + df = pd.read_sql_query( + """SELECT account_id, strategy, COUNT(*) as trade_count, + SUM(CASE WHEN side = 'buy' THEN quantity * price ELSE 0 END) as total_bought, + SUM(CASE WHEN side = 'sell' THEN quantity * price ELSE 0 END) as total_sold + FROM trades WHERE strategy IS NOT NULL + GROUP BY account_id, strategy""", + conn + ) conn.close() return df -def get_signal_stats(): - """Get signal statistics from engine""" +def get_signal_stats(account_id=None): conn = get_engine_connection() - query = """ - SELECT - strategy, - signal_type, - COUNT(*) as count, - AVG(strength) as avg_strength - FROM signals - GROUP BY strategy, signal_type - """ - df = pd.read_sql_query(query, conn) + if account_id: + df = pd.read_sql_query( + """SELECT strategy, signal_type, COUNT(*) as count, AVG(strength) as avg_strength + FROM signals WHERE account_id = ? + GROUP BY strategy, signal_type""", + conn, + params=(account_id,) + ) + else: + df = pd.read_sql_query( + """SELECT account_id, strategy, signal_type, COUNT(*) as count, AVG(strength) as avg_strength + FROM signals + GROUP BY account_id, strategy, signal_type""", + conn + ) conn.close() return df diff --git a/tests/broker_tests.rs b/tests/broker_tests.rs index 702274a..b43dbff 100644 --- a/tests/broker_tests.rs +++ b/tests/broker_tests.rs @@ -4,7 +4,7 @@ use cubert::types::{Order, OrderType, Side}; #[tokio::test] async fn test_paper_broker_buy_order() { - let broker = PaperBroker::new(100_000.0); + let broker = PaperBroker::new("test_buy", 100_000.0); broker.set_price("AAPL", 150.0); let order = Order { @@ -18,7 +18,7 @@ async fn test_paper_broker_buy_order() { assert!(order_id.starts_with("PAPER-")); let account = broker.get_account().await.unwrap(); - assert_eq!(account.cash, 85_000.0); // 100k - (100 * 150) + assert_eq!(account.cash, 85_000.0); assert_eq!(account.equity, 100_000.0); let position = broker.get_position("AAPL").await.unwrap().unwrap(); @@ -28,10 +28,9 @@ async fn test_paper_broker_buy_order() { #[tokio::test] async fn test_paper_broker_sell_order() { - let broker = PaperBroker::new(100_000.0); + let broker = PaperBroker::new("test_sell", 100_000.0); broker.set_price("AAPL", 150.0); - // First buy let buy_order = Order { symbol: "AAPL".to_string(), side: Side::Buy, @@ -40,7 +39,6 @@ async fn test_paper_broker_sell_order() { }; broker.submit_order(&buy_order).await.unwrap(); - // Update price and sell broker.set_price("AAPL", 160.0); let sell_order = Order { @@ -52,7 +50,7 @@ async fn test_paper_broker_sell_order() { broker.submit_order(&sell_order).await.unwrap(); let account = broker.get_account().await.unwrap(); - assert_eq!(account.cash, 101_000.0); // 85k + (100 * 160) + assert_eq!(account.cash, 101_000.0); assert_eq!(account.equity, 101_000.0); let position = broker.get_position("AAPL").await.unwrap(); @@ -61,13 +59,13 @@ async fn test_paper_broker_sell_order() { #[tokio::test] async fn test_paper_broker_insufficient_funds() { - let broker = PaperBroker::new(1_000.0); + let broker = PaperBroker::new("test_funds", 1_000.0); broker.set_price("AAPL", 150.0); let order = Order { symbol: "AAPL".to_string(), side: Side::Buy, - quantity: 100.0, // Would cost $15k + quantity: 100.0, order_type: OrderType::Market, }; @@ -77,10 +75,9 @@ async fn test_paper_broker_insufficient_funds() { #[tokio::test] async fn test_paper_broker_insufficient_shares() { - let broker = PaperBroker::new(100_000.0); + let broker = PaperBroker::new("test_shares", 100_000.0); broker.set_price("AAPL", 150.0); - // Buy 50 shares let buy_order = Order { symbol: "AAPL".to_string(), side: Side::Buy, @@ -89,7 +86,6 @@ async fn test_paper_broker_insufficient_shares() { }; broker.submit_order(&buy_order).await.unwrap(); - // Try to sell 100 let sell_order = Order { symbol: "AAPL".to_string(), side: Side::Sell, @@ -103,9 +99,8 @@ async fn test_paper_broker_insufficient_shares() { #[tokio::test] async fn test_paper_broker_position_averaging() { - let broker = PaperBroker::new(100_000.0); + let broker = PaperBroker::new("test_avg", 100_000.0); - // Buy 100 @ $100 broker.set_price("AAPL", 100.0); let order1 = Order { symbol: "AAPL".to_string(), @@ -115,7 +110,6 @@ async fn test_paper_broker_position_averaging() { }; broker.submit_order(&order1).await.unwrap(); - // Buy 100 @ $120 broker.set_price("AAPL", 120.0); let order2 = Order { symbol: "AAPL".to_string(), @@ -127,12 +121,12 @@ async fn test_paper_broker_position_averaging() { let position = broker.get_position("AAPL").await.unwrap().unwrap(); assert_eq!(position.quantity, 200.0); - assert_eq!(position.avg_entry_price, 110.0); // ($10k + $12k) / 200 + assert_eq!(position.avg_entry_price, 110.0); } #[tokio::test] async fn test_paper_broker_update_prices() { - let broker = PaperBroker::new(100_000.0); + let broker = PaperBroker::new("test_prices", 100_000.0); broker.set_price("AAPL", 150.0); let order = Order { @@ -143,12 +137,11 @@ async fn test_paper_broker_update_prices() { }; broker.submit_order(&order).await.unwrap(); - // Update prices broker.set_price("AAPL", 200.0); broker.update_prices(&["AAPL".to_string()]).await.unwrap(); let account = broker.get_account().await.unwrap(); - assert_eq!(account.equity, 105_000.0); // 85k cash + 100 * 200 + assert_eq!(account.equity, 105_000.0); let position = broker.get_position("AAPL").await.unwrap().unwrap(); assert_eq!(position.current_price, 200.0); @@ -156,7 +149,7 @@ async fn test_paper_broker_update_prices() { #[tokio::test] async fn test_paper_broker_multiple_positions() { - let broker = PaperBroker::new(100_000.0); + let broker = PaperBroker::new("test_multi", 100_000.0); broker.set_price("AAPL", 150.0); broker.set_price("MSFT", 300.0); @@ -181,16 +174,15 @@ async fn test_paper_broker_multiple_positions() { assert_eq!(positions.len(), 2); let account = broker.get_account().await.unwrap(); - assert_eq!(account.cash, 70_000.0); // 100k - 15k - 15k + assert_eq!(account.cash, 70_000.0); assert_eq!(account.equity, 100_000.0); } #[tokio::test] async fn test_paper_broker_partial_sell() { - let broker = PaperBroker::new(100_000.0); + let broker = PaperBroker::new("test_partial", 100_000.0); broker.set_price("AAPL", 100.0); - // Buy 100 shares let buy_order = Order { symbol: "AAPL".to_string(), side: Side::Buy, @@ -199,7 +191,6 @@ async fn test_paper_broker_partial_sell() { }; broker.submit_order(&buy_order).await.unwrap(); - // Sell 40 shares let sell_order = Order { symbol: "AAPL".to_string(), side: Side::Sell, @@ -212,13 +203,12 @@ async fn test_paper_broker_partial_sell() { assert_eq!(position.quantity, 60.0); let account = broker.get_account().await.unwrap(); - assert_eq!(account.cash, 94_000.0); // 90k + 4k + assert_eq!(account.cash, 94_000.0); } #[tokio::test] async fn test_paper_broker_limit_order() { - let broker = PaperBroker::new(100_000.0); - // Don't set market price - limit order uses its own price + let broker = PaperBroker::new("test_limit", 100_000.0); let order = Order { symbol: "AAPL".to_string(), @@ -230,7 +220,7 @@ async fn test_paper_broker_limit_order() { broker.submit_order(&order).await.unwrap(); let account = broker.get_account().await.unwrap(); - assert_eq!(account.cash, 85_500.0); // 100k - (100 * 145) + assert_eq!(account.cash, 85_500.0); let position = broker.get_position("AAPL").await.unwrap().unwrap(); assert_eq!(position.avg_entry_price, 145.0); diff --git a/tests/engine_tests.rs b/tests/engine_tests.rs index e05c18c..2e849a8 100644 --- a/tests/engine_tests.rs +++ b/tests/engine_tests.rs @@ -1,137 +1,395 @@ use cubert::storage::Storage; -#[tokio::test] -async fn test_full_trade_cycle() { - // Use in-memory database for tests +const ACCT: &str = "test_paper"; + +async fn setup() -> Storage { let storage = Storage::connect("sqlite::memory:").await.unwrap(); storage.migrate().await.unwrap(); - storage.init_account(100_000.0).await.unwrap(); + storage.init_account(ACCT, 100_000.0).await.unwrap(); + storage +} - // Verify initial state - let account = storage.get_account().await.unwrap(); +#[tokio::test] +async fn test_full_trade_cycle() { + let storage = setup().await; + + let account = storage.get_account(ACCT).await.unwrap(); assert_eq!(account.cash, 100_000.0); assert_eq!(account.equity, 100_000.0); - // Simulate a buy let buy_qty = 100.0; let buy_price = 150.0; let order_value = buy_qty * buy_price; - storage.deduct_cash(order_value).await.unwrap(); + storage.deduct_cash(ACCT, order_value).await.unwrap(); storage - .upsert_position("AAPL", buy_qty, buy_price, buy_price) + .upsert_position(ACCT, "AAPL", buy_qty, buy_price, buy_price) .await .unwrap(); storage .insert_trade( + ACCT, "AAPL", "buy", buy_qty, buy_price, Some("test-1"), - Some("test"), + Some("momentum"), ) .await .unwrap(); - // Verify after buy - let account = storage.get_account().await.unwrap(); + let account = storage.get_account(ACCT).await.unwrap(); assert_eq!(account.cash, 85_000.0); - assert_eq!(account.equity, 100_000.0); // Cash + position + assert_eq!(account.equity, 100_000.0); - // Simulate price increase - storage.update_position_price("AAPL", 160.0).await.unwrap(); + storage + .update_position_price(ACCT, "AAPL", 160.0) + .await + .unwrap(); - let account = storage.get_account().await.unwrap(); - assert_eq!(account.equity, 101_000.0); // $85k cash + $16k position + let account = storage.get_account(ACCT).await.unwrap(); + assert_eq!(account.equity, 101_000.0); - // Simulate a sell let sell_qty = 100.0; let sell_price = 160.0; let sell_value = sell_qty * sell_price; - storage.add_cash(sell_value).await.unwrap(); - storage.delete_position("AAPL").await.unwrap(); + storage.add_cash(ACCT, sell_value).await.unwrap(); + storage.delete_position(ACCT, "AAPL").await.unwrap(); storage .insert_trade( + ACCT, "AAPL", "sell", sell_qty, sell_price, Some("test-2"), - Some("test"), + Some("momentum"), ) .await .unwrap(); - // Verify after sell - let account = storage.get_account().await.unwrap(); + let account = storage.get_account(ACCT).await.unwrap(); assert_eq!(account.cash, 101_000.0); assert_eq!(account.equity, 101_000.0); - // Verify trades recorded - let trades = storage.get_trades(10).await.unwrap(); + let trades = storage.get_trades_by_account(ACCT, 10).await.unwrap(); assert_eq!(trades.len(), 2); } #[tokio::test] async fn test_multiple_positions() { - let storage = Storage::connect("sqlite::memory:").await.unwrap(); - storage.migrate().await.unwrap(); - storage.init_account(100_000.0).await.unwrap(); + let storage = setup().await; - // Buy AAPL - storage.deduct_cash(15_000.0).await.unwrap(); + storage.deduct_cash(ACCT, 15_000.0).await.unwrap(); storage - .upsert_position("AAPL", 100.0, 150.0, 150.0) + .upsert_position(ACCT, "AAPL", 100.0, 150.0, 150.0) .await .unwrap(); - // Buy MSFT - storage.deduct_cash(30_000.0).await.unwrap(); + storage.deduct_cash(ACCT, 30_000.0).await.unwrap(); storage - .upsert_position("MSFT", 100.0, 300.0, 300.0) + .upsert_position(ACCT, "MSFT", 100.0, 300.0, 300.0) .await .unwrap(); - // Verify - let account = storage.get_account().await.unwrap(); + let account = storage.get_account(ACCT).await.unwrap(); assert_eq!(account.cash, 55_000.0); assert_eq!(account.equity, 100_000.0); - let positions = storage.get_positions().await.unwrap(); + let positions = storage.get_positions(ACCT).await.unwrap(); assert_eq!(positions.len(), 2); } #[tokio::test] async fn test_position_averaging() { - let storage = Storage::connect("sqlite::memory:").await.unwrap(); - storage.migrate().await.unwrap(); - storage.init_account(100_000.0).await.unwrap(); + let storage = setup().await; - // First buy: 100 shares at $100 - storage.deduct_cash(10_000.0).await.unwrap(); + storage.deduct_cash(ACCT, 10_000.0).await.unwrap(); storage - .upsert_position("AAPL", 100.0, 100.0, 100.0) + .upsert_position(ACCT, "AAPL", 100.0, 100.0, 100.0) .await .unwrap(); - // Second buy: 100 shares at $120 (average in) - storage.deduct_cash(12_000.0).await.unwrap(); + storage.deduct_cash(ACCT, 12_000.0).await.unwrap(); - // Calculate new average - let pos = storage.get_position("AAPL").await.unwrap().unwrap(); + let pos = storage.get_position(ACCT, "AAPL").await.unwrap().unwrap(); let total_qty = pos.quantity + 100.0; let total_cost = (pos.quantity * pos.avg_entry_price) + 12_000.0; let new_avg = total_cost / total_qty; storage - .upsert_position("AAPL", total_qty, new_avg, 120.0) + .upsert_position(ACCT, "AAPL", total_qty, new_avg, 120.0) .await .unwrap(); - // Verify - let pos = storage.get_position("AAPL").await.unwrap().unwrap(); + let pos = storage.get_position(ACCT, "AAPL").await.unwrap().unwrap(); assert_eq!(pos.quantity, 200.0); - assert_eq!(pos.avg_entry_price, 110.0); // ($10k + $12k) / 200 shares + assert_eq!(pos.avg_entry_price, 110.0); +} + +#[tokio::test] +async fn test_account_isolation() { + let storage = Storage::connect("sqlite::memory:").await.unwrap(); + storage.migrate().await.unwrap(); + + storage.init_account("aggressive", 100_000.0).await.unwrap(); + storage + .init_account("conservative", 50_000.0) + .await + .unwrap(); + + storage.deduct_cash("aggressive", 15_000.0).await.unwrap(); + storage + .upsert_position("aggressive", "AAPL", 100.0, 150.0, 150.0) + .await + .unwrap(); + + storage.deduct_cash("conservative", 6_000.0).await.unwrap(); + storage + .upsert_position("conservative", "MSFT", 20.0, 300.0, 300.0) + .await + .unwrap(); + + let agg = storage.get_account("aggressive").await.unwrap(); + assert_eq!(agg.cash, 85_000.0); + assert_eq!(agg.equity, 100_000.0); + assert_eq!(agg.id, "aggressive"); + + let con = storage.get_account("conservative").await.unwrap(); + assert_eq!(con.cash, 44_000.0); + assert_eq!(con.equity, 50_000.0); + assert_eq!(con.id, "conservative"); + + let agg_positions = storage.get_positions("aggressive").await.unwrap(); + assert_eq!(agg_positions.len(), 1); + assert_eq!(agg_positions[0].symbol, "AAPL"); + + let con_positions = storage.get_positions("conservative").await.unwrap(); + assert_eq!(con_positions.len(), 1); + assert_eq!(con_positions[0].symbol, "MSFT"); +} + +#[tokio::test] +async fn test_same_symbol_different_accounts() { + let storage = Storage::connect("sqlite::memory:").await.unwrap(); + storage.migrate().await.unwrap(); + + storage.init_account("acct_a", 100_000.0).await.unwrap(); + storage.init_account("acct_b", 100_000.0).await.unwrap(); + + storage + .upsert_position("acct_a", "AAPL", 100.0, 150.0, 150.0) + .await + .unwrap(); + storage + .upsert_position("acct_b", "AAPL", 50.0, 155.0, 155.0) + .await + .unwrap(); + + let pos_a = storage + .get_position("acct_a", "AAPL") + .await + .unwrap() + .unwrap(); + assert_eq!(pos_a.quantity, 100.0); + assert_eq!(pos_a.avg_entry_price, 150.0); + + let pos_b = storage + .get_position("acct_b", "AAPL") + .await + .unwrap() + .unwrap(); + assert_eq!(pos_b.quantity, 50.0); + assert_eq!(pos_b.avg_entry_price, 155.0); + + storage.delete_position("acct_a", "AAPL").await.unwrap(); + + let pos_a = storage.get_position("acct_a", "AAPL").await.unwrap(); + assert!(pos_a.is_none()); + + let pos_b = storage + .get_position("acct_b", "AAPL") + .await + .unwrap() + .unwrap(); + assert_eq!(pos_b.quantity, 50.0); +} + +#[tokio::test] +async fn test_trades_scoped_to_account() { + let storage = Storage::connect("sqlite::memory:").await.unwrap(); + storage.migrate().await.unwrap(); + + storage + .insert_trade( + "acct_x", + "AAPL", + "buy", + 100.0, + 150.0, + Some("o1"), + Some("momentum"), + ) + .await + .unwrap(); + storage + .insert_trade( + "acct_x", + "MSFT", + "buy", + 50.0, + 300.0, + Some("o2"), + Some("momentum"), + ) + .await + .unwrap(); + storage + .insert_trade( + "acct_y", + "AAPL", + "buy", + 25.0, + 148.0, + Some("o3"), + Some("mean_rev"), + ) + .await + .unwrap(); + + let all = storage.get_trades(10).await.unwrap(); + assert_eq!(all.len(), 3); + + let x_trades = storage.get_trades_by_account("acct_x", 10).await.unwrap(); + assert_eq!(x_trades.len(), 2); + + let y_trades = storage.get_trades_by_account("acct_y", 10).await.unwrap(); + assert_eq!(y_trades.len(), 1); + assert_eq!(y_trades[0].symbol, "AAPL"); + assert_eq!(y_trades[0].strategy.as_ref().unwrap(), "mean_rev"); +} + +#[tokio::test] +async fn test_signals_scoped_to_account() { + let storage = Storage::connect("sqlite::memory:").await.unwrap(); + storage.migrate().await.unwrap(); + + storage + .insert_signal("acct_alpha", "AAPL", "buy", 0.9, "momentum") + .await + .unwrap(); + storage + .insert_signal("acct_alpha", "TSLA", "sell", 0.7, "momentum") + .await + .unwrap(); + storage + .insert_signal("acct_beta", "GOOGL", "buy", 0.5, "mean_rev") + .await + .unwrap(); + + let alpha = storage + .get_signals_by_account("acct_alpha", 10) + .await + .unwrap(); + assert_eq!(alpha.len(), 2); + + let beta = storage + .get_signals_by_account("acct_beta", 10) + .await + .unwrap(); + assert_eq!(beta.len(), 1); + assert_eq!(beta[0].symbol, "GOOGL"); +} + +#[tokio::test] +async fn test_snapshots_scoped_to_account() { + let storage = Storage::connect("sqlite::memory:").await.unwrap(); + storage.migrate().await.unwrap(); + + storage + .insert_account_snapshot("acct_one", 100_000.0, 80_000.0, 80_000.0) + .await + .unwrap(); + storage + .insert_account_snapshot("acct_one", 102_000.0, 82_000.0, 82_000.0) + .await + .unwrap(); + storage + .insert_account_snapshot("acct_two", 50_000.0, 50_000.0, 50_000.0) + .await + .unwrap(); + + let one = storage.get_account_history("acct_one", 10).await.unwrap(); + assert_eq!(one.len(), 2); + assert_eq!(one[0].equity, 102_000.0); + + let two = storage.get_account_history("acct_two", 10).await.unwrap(); + assert_eq!(two.len(), 1); + assert_eq!(two[0].equity, 50_000.0); +} + +#[tokio::test] +async fn test_init_account_idempotent() { + let storage = Storage::connect("sqlite::memory:").await.unwrap(); + storage.migrate().await.unwrap(); + + storage.init_account("paper_1", 100_000.0).await.unwrap(); + storage.init_account("paper_1", 50_000.0).await.unwrap(); + + let account = storage.get_account("paper_1").await.unwrap(); + assert_eq!(account.cash, 100_000.0); +} + +#[tokio::test] +async fn test_insufficient_funds_per_account() { + let storage = Storage::connect("sqlite::memory:").await.unwrap(); + storage.migrate().await.unwrap(); + + storage.init_account("rich", 100_000.0).await.unwrap(); + storage.init_account("poor", 1_000.0).await.unwrap(); + + let result = storage.deduct_cash("poor", 5_000.0).await; + assert!(result.is_err()); + + let rich = storage.get_account("rich").await.unwrap(); + assert_eq!(rich.cash, 100_000.0); +} + +#[tokio::test] +async fn test_price_update_scoped_to_account() { + let storage = Storage::connect("sqlite::memory:").await.unwrap(); + storage.migrate().await.unwrap(); + + storage.init_account("acct_a", 100_000.0).await.unwrap(); + storage.init_account("acct_b", 100_000.0).await.unwrap(); + + storage + .upsert_position("acct_a", "AAPL", 100.0, 150.0, 150.0) + .await + .unwrap(); + storage + .upsert_position("acct_b", "AAPL", 200.0, 145.0, 145.0) + .await + .unwrap(); + + storage + .update_position_price("acct_a", "AAPL", 160.0) + .await + .unwrap(); + + let pos_a = storage + .get_position("acct_a", "AAPL") + .await + .unwrap() + .unwrap(); + assert_eq!(pos_a.current_price, 160.0); + + let pos_b = storage + .get_position("acct_b", "AAPL") + .await + .unwrap() + .unwrap(); + assert_eq!(pos_b.current_price, 145.0); } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index a4cf02b..a41acbe 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -19,24 +19,24 @@ fn default_risk_config() -> RiskConfig { async fn test_full_trade_cycle() { let storage = Storage::connect("sqlite::memory:").await.unwrap(); storage.migrate().await.unwrap(); - storage.init_account(100_000.0).await.unwrap(); + storage.init_account("a", 100_000.0).await.unwrap(); - let account = storage.get_account().await.unwrap(); + let account = storage.get_account("a").await.unwrap(); assert_eq!(account.cash, 100_000.0); assert_eq!(account.equity, 100_000.0); - // Simulate a buy let buy_qty = 100.0; let buy_price = 150.0; let order_value = buy_qty * buy_price; - storage.deduct_cash(order_value).await.unwrap(); + storage.deduct_cash("a", order_value).await.unwrap(); storage - .upsert_position("AAPL", buy_qty, buy_price, buy_price) + .upsert_position("a", "AAPL", buy_qty, buy_price, buy_price) .await .unwrap(); storage .insert_trade( + "acct_cycle", "AAPL", "buy", buy_qty, @@ -47,25 +47,27 @@ async fn test_full_trade_cycle() { .await .unwrap(); - let account = storage.get_account().await.unwrap(); + let account = storage.get_account("a").await.unwrap(); assert_eq!(account.cash, 85_000.0); assert_eq!(account.equity, 100_000.0); - // Simulate price increase - storage.update_position_price("AAPL", 160.0).await.unwrap(); + storage + .update_position_price("a", "AAPL", 160.0) + .await + .unwrap(); - let account = storage.get_account().await.unwrap(); + let account = storage.get_account("a").await.unwrap(); assert_eq!(account.equity, 101_000.0); - // Simulate a sell let sell_qty = 100.0; let sell_price = 160.0; let sell_value = sell_qty * sell_price; - storage.add_cash(sell_value).await.unwrap(); - storage.delete_position("AAPL").await.unwrap(); + storage.add_cash("a", sell_value).await.unwrap(); + storage.delete_position("a", "AAPL").await.unwrap(); storage .insert_trade( + "acct_cycle", "AAPL", "sell", sell_qty, @@ -76,7 +78,7 @@ async fn test_full_trade_cycle() { .await .unwrap(); - let account = storage.get_account().await.unwrap(); + let account = storage.get_account("a").await.unwrap(); assert_eq!(account.cash, 101_000.0); assert_eq!(account.equity, 101_000.0); @@ -88,25 +90,25 @@ async fn test_full_trade_cycle() { async fn test_multiple_positions() { let storage = Storage::connect("sqlite::memory:").await.unwrap(); storage.migrate().await.unwrap(); - storage.init_account(100_000.0).await.unwrap(); + storage.init_account("a", 100_000.0).await.unwrap(); - storage.deduct_cash(15_000.0).await.unwrap(); + storage.deduct_cash("a", 15_000.0).await.unwrap(); storage - .upsert_position("AAPL", 100.0, 150.0, 150.0) + .upsert_position("a", "AAPL", 100.0, 150.0, 150.0) .await .unwrap(); - storage.deduct_cash(30_000.0).await.unwrap(); + storage.deduct_cash("a", 30_000.0).await.unwrap(); storage - .upsert_position("MSFT", 100.0, 300.0, 300.0) + .upsert_position("a", "MSFT", 100.0, 300.0, 300.0) .await .unwrap(); - let account = storage.get_account().await.unwrap(); + let account = storage.get_account("a").await.unwrap(); assert_eq!(account.cash, 55_000.0); assert_eq!(account.equity, 100_000.0); - let positions = storage.get_positions().await.unwrap(); + let positions = storage.get_positions("a").await.unwrap(); assert_eq!(positions.len(), 2); } @@ -114,34 +116,34 @@ async fn test_multiple_positions() { async fn test_position_averaging() { let storage = Storage::connect("sqlite::memory:").await.unwrap(); storage.migrate().await.unwrap(); - storage.init_account(100_000.0).await.unwrap(); + storage.init_account("a", 100_000.0).await.unwrap(); - storage.deduct_cash(10_000.0).await.unwrap(); + storage.deduct_cash("a", 10_000.0).await.unwrap(); storage - .upsert_position("AAPL", 100.0, 100.0, 100.0) + .upsert_position("a", "AAPL", 100.0, 100.0, 100.0) .await .unwrap(); - storage.deduct_cash(12_000.0).await.unwrap(); + storage.deduct_cash("a", 12_000.0).await.unwrap(); - let pos = storage.get_position("AAPL").await.unwrap().unwrap(); + let pos = storage.get_position("a", "AAPL").await.unwrap().unwrap(); let total_qty = pos.quantity + 100.0; let total_cost = (pos.quantity * pos.avg_entry_price) + 12_000.0; let new_avg = total_cost / total_qty; storage - .upsert_position("AAPL", total_qty, new_avg, 120.0) + .upsert_position("a", "AAPL", total_qty, new_avg, 120.0) .await .unwrap(); - let pos = storage.get_position("AAPL").await.unwrap().unwrap(); + let pos = storage.get_position("a", "AAPL").await.unwrap().unwrap(); assert_eq!(pos.quantity, 200.0); assert_eq!(pos.avg_entry_price, 110.0); } #[tokio::test] async fn test_risk_manager_with_broker() { - let broker = PaperBroker::new(100_000.0); + let broker = PaperBroker::new("acct_risk", 100_000.0); let mut rm = RiskManager::new(default_risk_config()); broker.set_price("AAPL", 150.0); @@ -170,47 +172,41 @@ async fn test_risk_manager_with_broker() { #[tokio::test] async fn test_signal_to_trade_flow() { - // Setup let storage = Storage::connect("sqlite::memory:").await.unwrap(); storage.migrate().await.unwrap(); - storage.init_account(100_000.0).await.unwrap(); + storage.init_account("a", 100_000.0).await.unwrap(); - let broker = PaperBroker::new(100_000.0); + let broker = PaperBroker::new("acct_flow", 100_000.0); let mut rm = RiskManager::new(default_risk_config()); broker.set_price("AAPL", 150.0); - // 1. Generate signal let signal = Signal::Buy { symbol: "AAPL".to_string(), strength: 0.8, }; - // 2. Log signal storage - .insert_signal("AAPL", "buy", 0.8, "momentum") + .insert_signal("acct_flow", "AAPL", "buy", 0.8, "momentum") .await .unwrap(); - // 3. Get account/positions let account = broker.get_account().await.unwrap(); let positions = broker.get_positions().await.unwrap(); - // 4. Risk check let order = rm .evaluate_signal(&signal, &account, &positions, 150.0) .unwrap(); - // 5. Execute order let order_id = broker.submit_order(&order).await.unwrap(); - // 6. Log trade let side = match order.side { Side::Buy => "buy", Side::Sell => "sell", }; storage .insert_trade( + "acct_flow", "AAPL", side, order.quantity, @@ -221,53 +217,61 @@ async fn test_signal_to_trade_flow() { .await .unwrap(); - // 7. Snapshot account let account = broker.get_account().await.unwrap(); storage - .insert_account_snapshot(account.equity, account.cash, account.buying_power) + .insert_account_snapshot( + "acct_flow", + account.equity, + account.cash, + account.buying_power, + ) .await .unwrap(); - // Verify let signals = storage.get_signals(10).await.unwrap(); assert_eq!(signals.len(), 1); let trades = storage.get_trades(10).await.unwrap(); assert_eq!(trades.len(), 1); - let history = storage.get_account_history(10).await.unwrap(); + let history = storage.get_account_history("acct_flow", 10).await.unwrap(); assert_eq!(history.len(), 1); } #[tokio::test] async fn test_dual_database_separation() { - // Broker storage (simulates broker API) let broker_storage = Storage::connect("sqlite::memory:").await.unwrap(); broker_storage.migrate().await.unwrap(); - broker_storage.init_account(100_000.0).await.unwrap(); + broker_storage.init_account("a", 100_000.0).await.unwrap(); - // Engine storage (for logging) let engine_storage = Storage::connect("sqlite::memory:").await.unwrap(); engine_storage.migrate().await.unwrap(); - // Execute a trade via broker storage - broker_storage.deduct_cash(15_000.0).await.unwrap(); + broker_storage.deduct_cash("a", 15_000.0).await.unwrap(); broker_storage - .upsert_position("AAPL", 100.0, 150.0, 150.0) + .upsert_position("a", "AAPL", 100.0, 150.0, 150.0) .await .unwrap(); broker_storage - .insert_trade("AAPL", "buy", 100.0, 150.0, Some("order-1"), None) + .insert_trade( + "acct_dual", + "AAPL", + "buy", + 100.0, + 150.0, + Some("order-1"), + None, + ) .await .unwrap(); - // Log to engine storage engine_storage - .insert_signal("AAPL", "buy", 0.8, "momentum") + .insert_signal("acct_dual", "AAPL", "buy", 0.8, "momentum") .await .unwrap(); engine_storage .insert_trade( + "acct_dual", "AAPL", "buy", 100.0, @@ -278,13 +282,17 @@ async fn test_dual_database_separation() { .await .unwrap(); - let account = broker_storage.get_account().await.unwrap(); + let account = broker_storage.get_account("a").await.unwrap(); engine_storage - .insert_account_snapshot(account.equity, account.cash, account.buying_power) + .insert_account_snapshot( + "acct_dual", + account.equity, + account.cash, + account.buying_power, + ) .await .unwrap(); - // Verify separation let broker_trades = broker_storage.get_trades(10).await.unwrap(); assert_eq!(broker_trades.len(), 1); assert!(broker_trades[0].strategy.is_none()); @@ -296,7 +304,6 @@ async fn test_dual_database_separation() { let engine_signals = engine_storage.get_signals(10).await.unwrap(); assert_eq!(engine_signals.len(), 1); - // Broker storage should have no signals let broker_signals = broker_storage.get_signals(10).await.unwrap(); assert_eq!(broker_signals.len(), 0); } diff --git a/tests/risk_tests.rs b/tests/risk_tests.rs index 2307a10..04fd53a 100644 --- a/tests/risk_tests.rs +++ b/tests/risk_tests.rs @@ -14,6 +14,7 @@ fn default_risk_config() -> RiskConfig { fn default_account() -> Account { Account { + id: "test_default".to_string(), equity: 100_000.0, cash: 100_000.0, buying_power: 100_000.0, @@ -74,7 +75,7 @@ fn test_risk_manager_position_sizing() { let order = rm .evaluate_signal(&signal, &account, &positions, 100.0) .unwrap(); - assert_eq!(order.quantity, 100.0); // 10% of 100k at $100 + assert_eq!(order.quantity, 100.0); } #[test] diff --git a/tests/storage.rs b/tests/storage.rs deleted file mode 100644 index d5c324a..0000000 --- a/tests/storage.rs +++ /dev/null @@ -1,188 +0,0 @@ -use cubert::storage::Storage; - -async fn setup_test_storage() -> Storage { - // Use in-memory database - let storage = Storage::connect("sqlite::memory:").await.unwrap(); - storage.migrate().await.unwrap(); - storage -} - -#[tokio::test] -async fn test_account_init() { - let storage = setup_test_storage().await; - - // Init account - storage.init_account(100_000.0).await.unwrap(); - - // Get account - let account = storage.get_account().await.unwrap(); - assert_eq!(account.cash, 100_000.0); - assert_eq!(account.equity, 100_000.0); -} - -#[tokio::test] -async fn test_account_init_idempotent() { - let storage = setup_test_storage().await; - - // Init twice with different amounts - storage.init_account(100_000.0).await.unwrap(); - storage.init_account(50_000.0).await.unwrap(); // Should be ignored - - // Should still be first value - let account = storage.get_account().await.unwrap(); - assert_eq!(account.cash, 100_000.0); -} - -#[tokio::test] -async fn test_cash_operations() { - let storage = setup_test_storage().await; - storage.init_account(10_000.0).await.unwrap(); - - // Deduct - let remaining = storage.deduct_cash(3_000.0).await.unwrap(); - assert_eq!(remaining, 7_000.0); - - // Add - let new_total = storage.add_cash(1_000.0).await.unwrap(); - assert_eq!(new_total, 8_000.0); -} - -#[tokio::test] -async fn test_insufficient_funds() { - let storage = setup_test_storage().await; - storage.init_account(1_000.0).await.unwrap(); - - // Try to deduct more than available - let result = storage.deduct_cash(5_000.0).await; - assert!(result.is_err()); -} - -#[tokio::test] -async fn test_position_crud() { - let storage = setup_test_storage().await; - storage.init_account(100_000.0).await.unwrap(); - - // Create position - storage - .upsert_position("AAPL", 100.0, 150.0, 155.0) - .await - .unwrap(); - - // Read position - let pos = storage.get_position("AAPL").await.unwrap().unwrap(); - assert_eq!(pos.symbol, "AAPL"); - assert_eq!(pos.quantity, 100.0); - assert_eq!(pos.avg_entry_price, 150.0); - - // Update position - storage - .upsert_position("AAPL", 150.0, 152.0, 160.0) - .await - .unwrap(); - let pos = storage.get_position("AAPL").await.unwrap().unwrap(); - assert_eq!(pos.quantity, 150.0); - - // Delete position - storage.delete_position("AAPL").await.unwrap(); - let pos = storage.get_position("AAPL").await.unwrap(); - assert!(pos.is_none()); -} - -#[tokio::test] -async fn test_equity_calculation() { - let storage = setup_test_storage().await; - storage.init_account(50_000.0).await.unwrap(); - - // Add a position worth $15,000 - storage - .upsert_position("AAPL", 100.0, 150.0, 150.0) - .await - .unwrap(); - - let account = storage.get_account().await.unwrap(); - assert_eq!(account.cash, 50_000.0); - assert_eq!(account.equity, 65_000.0); // cash + position value -} - -#[tokio::test] -async fn test_trade_insert_and_query() { - let storage = setup_test_storage().await; - - // Insert trades - storage - .insert_trade( - "AAPL", - "buy", - 100.0, - 150.0, - Some("order-1"), - Some("momentum"), - ) - .await - .unwrap(); - storage - .insert_trade( - "MSFT", - "buy", - 50.0, - 300.0, - Some("order-2"), - Some("momentum"), - ) - .await - .unwrap(); - storage - .insert_trade( - "AAPL", - "sell", - 100.0, - 160.0, - Some("order-3"), - Some("momentum"), - ) - .await - .unwrap(); - - // Query all trades - let trades = storage.get_trades(10).await.unwrap(); - assert_eq!(trades.len(), 3); - - // Query by symbol - let aapl_trades = storage.get_trades_by_symbol("AAPL", 10).await.unwrap(); - assert_eq!(aapl_trades.len(), 2); -} - -#[tokio::test] -async fn test_signal_insert_and_query() { - let storage = setup_test_storage().await; - - storage - .insert_signal("AAPL", "buy", 0.8, "momentum") - .await - .unwrap(); - storage - .insert_signal("MSFT", "sell", 0.6, "mean_reversion") - .await - .unwrap(); - - let signals = storage.get_signals(10).await.unwrap(); - assert_eq!(signals.len(), 2); -} - -#[tokio::test] -async fn test_account_snapshot() { - let storage = setup_test_storage().await; - - storage - .insert_account_snapshot(100_000.0, 80_000.0, 80_000.0) - .await - .unwrap(); - storage - .insert_account_snapshot(101_000.0, 81_000.0, 81_000.0) - .await - .unwrap(); - - let history = storage.get_account_history(10).await.unwrap(); - assert_eq!(history.len(), 2); - assert_eq!(history[0].equity, 101_000.0); // Most recent first -} diff --git a/tests/storage_tests.rs b/tests/storage_tests.rs new file mode 100644 index 0000000..d9429cb --- /dev/null +++ b/tests/storage_tests.rs @@ -0,0 +1,281 @@ +use cubert::storage::Storage; + +async fn setup_test_storage() -> Storage { + let storage = Storage::connect("sqlite::memory:").await.unwrap(); + storage.migrate().await.unwrap(); + storage +} + +#[tokio::test] +async fn test_account_init() { + let storage = setup_test_storage().await; + storage.init_account("a", 100_000.0).await.unwrap(); + + let account = storage.get_account("a").await.unwrap(); + assert_eq!(account.cash, 100_000.0); + assert_eq!(account.equity, 100_000.0); +} + +#[tokio::test] +async fn test_account_init_idempotent() { + let storage = setup_test_storage().await; + + storage.init_account("a", 100_000.0).await.unwrap(); + storage.init_account("b", 50_000.0).await.unwrap(); + + let account = storage.get_account("a").await.unwrap(); + assert_eq!(account.cash, 100_000.0); +} + +#[tokio::test] +async fn test_cash_operations() { + let storage = setup_test_storage().await; + storage.init_account("a", 10_000.0).await.unwrap(); + + let remaining = storage.deduct_cash("a", 3_000.0).await.unwrap(); + assert_eq!(remaining, 7_000.0); + + let new_total = storage.add_cash("a", 1_000.0).await.unwrap(); + assert_eq!(new_total, 8_000.0); +} + +#[tokio::test] +async fn test_insufficient_funds() { + let storage = setup_test_storage().await; + storage.init_account("a", 1_000.0).await.unwrap(); + + let result = storage.deduct_cash("a", 5_000.0).await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_position_crud() { + let storage = setup_test_storage().await; + storage.init_account("a", 100_000.0).await.unwrap(); + + storage + .upsert_position("a", "AAPL", 100.0, 150.0, 155.0) + .await + .unwrap(); + + let pos = storage.get_position("a", "AAPL").await.unwrap().unwrap(); + assert_eq!(pos.symbol, "AAPL"); + assert_eq!(pos.quantity, 100.0); + assert_eq!(pos.avg_entry_price, 150.0); + + storage + .upsert_position("a", "AAPL", 150.0, 152.0, 160.0) + .await + .unwrap(); + let pos = storage.get_position("a", "AAPL").await.unwrap().unwrap(); + assert_eq!(pos.quantity, 150.0); + + storage.delete_position("a", "AAPL").await.unwrap(); + let pos = storage.get_position("a", "AAPL").await.unwrap(); + assert!(pos.is_none()); +} + +#[tokio::test] +async fn test_equity_calculation() { + let storage = setup_test_storage().await; + storage.init_account("a", 50_000.0).await.unwrap(); + + storage + .upsert_position("a", "AAPL", 100.0, 150.0, 150.0) + .await + .unwrap(); + + let account = storage.get_account("a").await.unwrap(); + assert_eq!(account.cash, 50_000.0); + assert_eq!(account.equity, 65_000.0); +} + +#[tokio::test] +async fn test_trade_insert_and_query() { + let storage = setup_test_storage().await; + + storage + .insert_trade( + "acct_trades", + "AAPL", + "buy", + 100.0, + 150.0, + Some("order-1"), + Some("Momentum"), + ) + .await + .unwrap(); + storage + .insert_trade( + "acct_trades", + "MSFT", + "buy", + 50.0, + 300.0, + Some("order-2"), + Some("Momentum"), + ) + .await + .unwrap(); + storage + .insert_trade( + "acct_trades", + "AAPL", + "sell", + 100.0, + 160.0, + Some("order-3"), + Some("Momentum"), + ) + .await + .unwrap(); + + let trades = storage.get_trades(10).await.unwrap(); + assert_eq!(trades.len(), 3); + + let aapl_trades = storage.get_trades_by_symbol("AAPL", 10).await.unwrap(); + assert_eq!(aapl_trades.len(), 2); + + let acct_trades = storage + .get_trades_by_account("acct_trades", 10) + .await + .unwrap(); + assert_eq!(acct_trades.len(), 3); +} + +#[tokio::test] +async fn test_signal_insert_and_query() { + let storage = setup_test_storage().await; + + storage + .insert_signal("acct_signals", "AAPL", "buy", 0.8, "Momentum") + .await + .unwrap(); + storage + .insert_signal("acct_signals", "MSFT", "sell", 0.6, "MeanReversion") + .await + .unwrap(); + + let signals = storage.get_signals(10).await.unwrap(); + assert_eq!(signals.len(), 2); + + let acct_signals = storage + .get_signals_by_account("acct_signals", 10) + .await + .unwrap(); + assert_eq!(acct_signals.len(), 2); +} + +#[tokio::test] +async fn test_account_snapshot() { + let storage = setup_test_storage().await; + + storage + .insert_account_snapshot("acct_snap", 100_000.0, 80_000.0, 80_000.0) + .await + .unwrap(); + storage + .insert_account_snapshot("acct_snap", 101_000.0, 81_000.0, 81_000.0) + .await + .unwrap(); + + let history = storage.get_account_history("acct_snap", 10).await.unwrap(); + assert_eq!(history.len(), 2); + assert_eq!(history[0].equity, 101_000.0); +} + +#[tokio::test] +async fn test_strategy_crud() { + let storage = setup_test_storage().await; + + let params = cubert::strategy::StrategyParams::Momentum { + lookback_period: 20, + threshold: 0.03, + startup_lookback: "1Hour".to_string(), + startup_bar_limit: 50, + }; + + let id = storage + .insert_strategy("fast_momentum", ¶ms) + .await + .unwrap(); + + let strat = storage.get_strategy_by_id(&id).await.unwrap(); + assert_eq!(strat.name, "fast_momentum"); + assert_eq!(strat.strategy_type, "Momentum"); + + let all = storage.get_strategies().await.unwrap(); + assert_eq!(all.len(), 1); + + storage + .update_strategy(&id, "slow_momentum", "Momentum", &strat.params_json) + .await + .unwrap(); + + let updated = storage.get_strategy_by_id(&id).await.unwrap(); + assert_eq!(updated.name, "slow_momentum"); + + storage.delete_strategy(&id).await.unwrap(); + let all = storage.get_strategies().await.unwrap(); + assert_eq!(all.len(), 0); +} + +#[tokio::test] +async fn test_account_strategy_assignment() { + let storage = setup_test_storage().await; + + let params = cubert::strategy::StrategyParams::Momentum { + lookback_period: 10, + threshold: 0.02, + startup_lookback: "5Min".to_string(), + startup_bar_limit: 100, + }; + + let strat_id = storage + .insert_strategy("test_strat", ¶ms) + .await + .unwrap(); + + storage + .assign_strategy_to_account( + "paper_aggressive", + &strat_id, + &["AAPL".to_string(), "TSLA".to_string()], + ) + .await + .unwrap(); + + let assignments = storage + .get_strategies_for_account("paper_aggressive") + .await + .unwrap(); + assert_eq!(assignments.len(), 1); + + let (strat, assignment) = &assignments[0]; + assert_eq!(strat.name, "test_strat"); + assert_eq!(assignment.account_id, "paper_aggressive"); + + let account_ids = storage.get_active_account_ids().await.unwrap(); + assert_eq!(account_ids.len(), 1); + assert_eq!(account_ids[0], "paper_aggressive"); + + storage + .toggle_strategy_enabled("paper_aggressive", &strat_id, false) + .await + .unwrap(); + + let account_ids = storage.get_active_account_ids().await.unwrap(); + assert_eq!(account_ids.len(), 0); + + storage + .unassign_strategy_from_account("paper_aggressive", &strat_id) + .await + .unwrap(); + + let assignments = storage + .get_strategies_for_account("paper_aggressive") + .await + .unwrap(); + assert_eq!(assignments.len(), 0); +} diff --git a/tests/strategy_tests.rs b/tests/strategy_tests.rs index ceeba5b..9382633 100644 --- a/tests/strategy_tests.rs +++ b/tests/strategy_tests.rs @@ -1,8 +1,8 @@ use std::time::SystemTime; -use cubert::strategy::Strategy; // Add this +use cubert::strategy::Strategy; use cubert::strategy::momentum::MomentumStrategy; -use cubert::types::{Bar, Signal}; // Add Bar here +use cubert::types::{Bar, Signal}; fn make_bar(symbol: &str, close: f64) -> Bar { Bar { @@ -18,24 +18,35 @@ fn make_bar(symbol: &str, close: f64) -> Bar { #[test] fn test_momentum_strategy_creation() { - let strategy = MomentumStrategy::new("test_momentum", vec!["AAPL".to_string()], 5, 0.02); + let strategy = MomentumStrategy::new( + "aggressive_momentum", + vec!["AAPL".to_string(), "MSFT".to_string()], + 3, + 0.01, + "5Min".to_string(), + 100, + ); - assert_eq!(strategy.name(), "test_momentum"); - assert_eq!(strategy.symbols(), &["AAPL".to_string()]); + assert_eq!(strategy.name(), "aggressive_momentum"); + assert_eq!( + strategy.symbols(), + &["AAPL".to_string(), "MSFT".to_string()] + ); } #[test] fn test_momentum_needs_warmup() { let mut strategy = MomentumStrategy::new( - "test", - vec!["AAPL".to_string()], - 5, // lookback period - 0.02, + "slow_momentum", + vec!["GOOGL".to_string()], + 8, + 0.03, + "1Day".to_string(), + 200, ); - // First few bars should return None (not enough data) - for i in 0..4 { - let bar = make_bar("AAPL", 100.0 + i as f64); + for i in 0..7 { + let bar = make_bar("GOOGL", 200.0 + i as f64); let signal = strategy.on_bar(&bar); assert!( signal.is_none(), @@ -48,23 +59,23 @@ fn test_momentum_needs_warmup() { #[test] fn test_momentum_generates_buy_signal() { let mut strategy = MomentumStrategy::new( - "test", - vec!["AAPL".to_string()], + "mid_momentum", + vec!["TSLA".to_string()], 5, - 0.02, // 2% threshold + 0.02, + "15Min".to_string(), + 75, ); - // Feed rising prices: 100, 101, 102, 103, 104, 105 for i in 0..6 { let price = 100.0 + i as f64; - let bar = make_bar("AAPL", price); + let bar = make_bar("TSLA", price); let signal = strategy.on_bar(&bar); if i >= 5 { - // After warmup, should generate buy signal (5% gain > 2% threshold) match signal { Some(Signal::Buy { symbol, strength }) => { - assert_eq!(symbol, "AAPL"); + assert_eq!(symbol, "TSLA"); assert!(strength > 0.0); } _ => panic!("Expected Buy signal, got {:?}", signal), @@ -75,19 +86,24 @@ fn test_momentum_generates_buy_signal() { #[test] fn test_momentum_generates_sell_signal() { - let mut strategy = MomentumStrategy::new("test", vec!["AAPL".to_string()], 5, 0.02); + let mut strategy = MomentumStrategy::new( + "conservative_momentum", + vec!["NVDA".to_string()], + 5, + 0.02, + "30Min".to_string(), + 120, + ); - // Feed falling prices: 110, 108, 106, 104, 102, 100 for i in 0..6 { let price = 110.0 - (i as f64 * 2.0); - let bar = make_bar("AAPL", price); + let bar = make_bar("NVDA", price); let signal = strategy.on_bar(&bar); if i >= 5 { - // After warmup, should generate sell signal match signal { Some(Signal::Sell { symbol, strength }) => { - assert_eq!(symbol, "AAPL"); + assert_eq!(symbol, "NVDA"); assert!(strength > 0.0); } _ => panic!("Expected Sell signal, got {:?}", signal), @@ -99,21 +115,21 @@ fn test_momentum_generates_sell_signal() { #[test] fn test_momentum_no_signal_in_range() { let mut strategy = MomentumStrategy::new( - "test", - vec!["AAPL".to_string()], - 5, - 0.10, // 10% threshold - high + "high_threshold", + vec!["AMZN".to_string()], + 4, + 0.10, + "1Hour".to_string(), + 50, ); - // Feed flat prices for _ in 0..10 { - let bar = make_bar("AAPL", 100.0); + let bar = make_bar("AMZN", 100.0); let signal = strategy.on_bar(&bar); - // Should be None or Hold (no significant movement) if let Some(sig) = signal { match sig { - Signal::Hold => {} // OK + Signal::Hold => {} Signal::Buy { .. } | Signal::Sell { .. } => { panic!("Should not generate Buy/Sell signal for flat prices"); } @@ -124,30 +140,39 @@ fn test_momentum_no_signal_in_range() { #[test] fn test_momentum_ignores_other_symbols() { - let mut strategy = MomentumStrategy::new("test", vec!["AAPL".to_string()], 5, 0.02); + let mut strategy = MomentumStrategy::new( + "aapl_only", + vec!["AAPL".to_string()], + 6, + 0.05, + "1Day".to_string(), + 30, + ); - // Feed bars for a different symbol let bar = make_bar("MSFT", 100.0); let signal = strategy.on_bar(&bar); - assert!(signal.is_none(), "Should ignore symbols not in strategy"); } #[test] fn test_strategy_reset() { - let mut strategy = MomentumStrategy::new("test", vec!["AAPL".to_string()], 5, 0.02); + let mut strategy = MomentumStrategy::new( + "reset_test", + vec!["META".to_string()], + 5, + 0.02, + "5Min".to_string(), + 60, + ); - // Feed some data for i in 0..10 { - let bar = make_bar("AAPL", 100.0 + i as f64); + let bar = make_bar("META", 100.0 + i as f64); strategy.on_bar(&bar); } - // Reset strategy.reset(); - // After reset, should need warmup again - let bar = make_bar("AAPL", 150.0); + let bar = make_bar("META", 150.0); let signal = strategy.on_bar(&bar); assert!(signal.is_none(), "Should need warmup after reset"); }