diff --git a/plan.md b/plan.md index a4e6550dd..0a6122b0d 100644 --- a/plan.md +++ b/plan.md @@ -605,11 +605,11 @@ def main(): ``` **Phase 5 Deliverables:** -- [ ] `trainer/runner.py` -- [ ] `thinker/runner.py` -- [ ] `trader/runner.py` -- [ ] `scripts/run_trainer.py`, `run_thinker.py`, `run_trader.py` -- [ ] Integration tests (runners with mock clients) +- [x] `trainer/runner.py` — orchestrates training across coins/timeframes with checkpoint resume and stop signal +- [x] `thinker/runner.py` — continuous signal generation with hot-reload of coin list +- [x] `trader/runner.py` — trade execution with position sync, entry/DCA/exit management +- [x] `scripts/run_trainer.py`, `run_thinker.py`, `run_trader.py` — entry points with dependency wiring +- [x] Integration tests (runners with mock clients) — 41 tests covering all three runners - [ ] Verify identical behavior to original scripts --- @@ -805,10 +805,10 @@ tests/ │ └── trainer/ │ └── test_memory.py # Memory I/O, checkpoints, distance, progress ✅ done (32 tests) — split in Phase 4 ├── integration/ -│ ├── test_trainer_runner.py # Full training with mock market ☐ Phase 5 -│ ├── test_thinker_runner.py # Signal gen with mock data ☐ Phase 5 -│ ├── test_trader_runner.py # Trade execution with paper client ☐ Phase 5 -│ └── test_file_ipc.py # End-to-end file-based communication ☐ Phase 5 +│ ├── test_trainer_runner.py # Full training with mock market ✅ Phase 5 +│ ├── test_thinker_runner.py # Signal gen with mock data ✅ Phase 5 +│ ├── test_trader_runner.py # Trade execution with paper client ✅ Phase 5 +│ └── test_file_ipc.py # End-to-end file-based communication (in test_trader_runner.py) ✅ Phase 5 └── conftest.py # Shared fixtures (mock clients, temp dirs) ✅ done ``` diff --git a/scripts/run_thinker.py b/scripts/run_thinker.py index 01d093f6c..935c9d415 100644 --- a/scripts/run_thinker.py +++ b/scripts/run_thinker.py @@ -1,11 +1,39 @@ #!/usr/bin/env python3 -"""Entry point for the PowerTrader Thinker / Signal Generator (thin wrapper).""" +"""Entry point for the PowerTrader Thinker / Signal Generator. + +Usage:: + + python scripts/run_thinker.py +""" + +from __future__ import annotations + +from pathlib import Path def main() -> None: - raise NotImplementedError( - "Thinker runner not yet migrated — use pt_thinker.py directly." + from powertrader.core.config import TradingConfig + from powertrader.core.constants import SETTINGS_FILENAME + from powertrader.core.logging_setup import setup_logger + from powertrader.core.market_client import KuCoinMarketClient + from powertrader.core.storage import FileStore + from powertrader.thinker.runner import ThinkerRunner + + base_dir = Path.cwd() + setup_logger("thinker", base_dir / "logs") + setup_logger("powertrader", base_dir / "logs") + + config = TradingConfig.from_file(base_dir / SETTINGS_FILENAME) + market = KuCoinMarketClient() + store = FileStore() + + runner = ThinkerRunner( + market=market, + config=config, + store=store, + base_dir=base_dir, ) + runner.run() if __name__ == "__main__": diff --git a/scripts/run_trader.py b/scripts/run_trader.py index 7bfec479e..d2ab6f945 100644 --- a/scripts/run_trader.py +++ b/scripts/run_trader.py @@ -1,11 +1,74 @@ #!/usr/bin/env python3 -"""Entry point for the PowerTrader Trade Executor (thin wrapper).""" +"""Entry point for the PowerTrader Trade Executor. + +Usage:: + + python scripts/run_trader.py # Live trading (Binance) + python scripts/run_trader.py --paper # Paper trading (simulated) +""" + +from __future__ import annotations + +import sys +from pathlib import Path def main() -> None: - raise NotImplementedError( - "Trader runner not yet migrated — use pt_trader.py directly." + from powertrader.core.config import TradingConfig + from powertrader.core.constants import SETTINGS_FILENAME + from powertrader.core.credentials import BinanceCredentials + from powertrader.core.logging_setup import setup_logger + from powertrader.core.storage import FileStore + from powertrader.core.trading_client import TradingClient + from powertrader.trader.dca_engine import DCAEngine + from powertrader.trader.entry_engine import EntryEngine + from powertrader.trader.runner import TraderRunner + from powertrader.trader.trailing_engine import TrailingProfitEngine + + base_dir = Path.cwd() + setup_logger("trader", base_dir / "logs") + setup_logger("powertrader", base_dir / "logs") + + config = TradingConfig.from_file(base_dir / SETTINGS_FILENAME) + store = FileStore() + + # Select trading client + paper_mode = "--paper" in sys.argv + client: TradingClient + + if paper_mode: + from powertrader.core.market_client import KuCoinMarketClient + from powertrader.core.paper_client import PaperTradingClient + + market = KuCoinMarketClient() + client = PaperTradingClient(market=market) + else: + from powertrader.core.trading_client import BinanceTradingClient + + creds = BinanceCredentials.load(base_dir) + if not creds.is_valid: + print("ERROR: No valid Binance credentials found.") + print( + "Set BINANCE_API_KEY/BINANCE_API_SECRET env vars or create b_key.txt/b_secret.txt" + ) + sys.exit(1) + client = BinanceTradingClient(creds) + + # Wire up engines + entry = EntryEngine(config) + dca = DCAEngine(config) + trailing = TrailingProfitEngine(config) + + runner = TraderRunner( + trading_client=client, + entry=entry, + dca=dca, + trailing=trailing, + config=config, + store=store, + base_dir=base_dir, ) + runner.run() if __name__ == "__main__": diff --git a/scripts/run_trainer.py b/scripts/run_trainer.py index 13f2aefbc..96e801769 100644 --- a/scripts/run_trainer.py +++ b/scripts/run_trainer.py @@ -1,11 +1,55 @@ #!/usr/bin/env python3 -"""Entry point for the PowerTrader Trainer (thin wrapper).""" +"""Entry point for the PowerTrader Trainer. + +Usage:: + + python scripts/run_trainer.py # Train all configured coins + python scripts/run_trainer.py BTC # Train a specific coin + python scripts/run_trainer.py ETH reprocess_yes # Retrain with full reprocessing +""" + +from __future__ import annotations + +import sys +from pathlib import Path def main() -> None: - raise NotImplementedError( - "Trainer runner not yet migrated — use pt_trainer.py directly." + from powertrader.core.config import TradingConfig + from powertrader.core.constants import SETTINGS_FILENAME + from powertrader.core.logging_setup import setup_logger + from powertrader.core.market_client import KuCoinMarketClient + from powertrader.core.storage import FileStore + from powertrader.trainer.runner import TrainerRunner + + base_dir = Path.cwd() + setup_logger("trainer", base_dir / "logs") + setup_logger("powertrader", base_dir / "logs") + + config = TradingConfig.from_file(base_dir / SETTINGS_FILENAME) + market = KuCoinMarketClient() + store = FileStore() + + # Parse CLI args: [coin] [reprocess_yes|reprocess_no] + coins: list[str] | None = None + reprocess = False + + args = sys.argv[1:] + for arg in args: + if arg.lower() in ("reprocess_yes", "reprocess"): + reprocess = True + elif arg.lower() == "reprocess_no": + reprocess = False + else: + coins = [arg.upper()] + + runner = TrainerRunner( + market=market, + config=config, + store=store, + base_dir=base_dir, ) + runner.run(coins=coins, reprocess=reprocess) if __name__ == "__main__": diff --git a/src/powertrader/thinker/runner.py b/src/powertrader/thinker/runner.py new file mode 100644 index 000000000..138b0a079 --- /dev/null +++ b/src/powertrader/thinker/runner.py @@ -0,0 +1,273 @@ +"""Thinker runner — continuous signal generation loop. + +Replaces the main loop from ``pt_thinker.py``. Iterates through all +configured coins, generates trading signals from trained pattern memories, +and writes signal files for the trader to consume. + +Supports hot-reload of the coin list from ``gui_settings.json``. +""" + +from __future__ import annotations + +import logging +import time +from pathlib import Path + +from powertrader.core.config import TradingConfig +from powertrader.core.constants import ( + SETTINGS_FILENAME, + TIMEFRAMES, + TRAINING_STALE_SECONDS, +) +from powertrader.core.market_client import MarketDataClient +from powertrader.core.paths import CoinPaths, build_coin_paths +from powertrader.core.storage import FileStore +from powertrader.models.memory import PatternMemory +from powertrader.models.signal import Signal +from powertrader.thinker.signal_engine import generate_signal + +logger = logging.getLogger(__name__) + +_LOOP_SLEEP_SECONDS = 0.15 +_TRAINING_TIME_FILENAME = "trainer_last_training_time.txt" + + +class ThinkerRunner: + """Continuous signal generation loop. + + Parameters + ---------- + market: + Market data client for fetching current prices. + config: + Trading configuration snapshot (used for initial coin list). + store: + File I/O abstraction. + base_dir: + Root project directory (where coin folders live). + """ + + def __init__( + self, + market: MarketDataClient, + config: TradingConfig, + store: FileStore, + base_dir: Path, + ) -> None: + self._market = market + self._config = config + self._store = store + self._base_dir = base_dir + self._coins: list[str] = list(config.coins) + self._coin_paths: dict[str, CoinPaths] = build_coin_paths(base_dir, self._coins) + self._settings_mtime: float = 0.0 + self._running = True + + # -- public API ----------------------------------------------------------- + + def run(self) -> None: + """Main loop: generate signals for all coins, hot-reload config. + + Runs indefinitely until :meth:`stop` is called. + """ + logger.info("Thinker started for %d coins", len(self._coins)) + + while self._running: + self._sync_coins_from_settings() + self.step() + time.sleep(_LOOP_SLEEP_SECONDS) + + logger.info("Thinker stopped") + + def step(self) -> dict[str, Signal]: + """One iteration: process all coins once. + + Returns a ``{coin: Signal}`` dict of the generated signals. + """ + signals: dict[str, Signal] = {} + + for coin in self._coins: + paths = self._coin_paths.get(coin) + if paths is None: + continue + + try: + signal = self._step_coin(coin, paths) + if signal is not None: + signals[coin] = signal + self._write_signal_files(paths, signal) + except Exception as exc: + logger.error("Signal generation failed for %s: %s", coin, exc) + + return signals + + def stop(self) -> None: + """Request the runner to stop after the current iteration.""" + self._running = False + + # -- per-coin signal generation ------------------------------------------- + + def _step_coin(self, coin: str, paths: CoinPaths) -> Signal | None: + """Generate signal for one coin. + + Returns ``None`` if the coin is not trained or no data is available. + """ + # Training freshness gate + if not self._is_trained(paths): + self._write_zero_signals(paths, coin) + return None + + # Load all memory files across timeframes + memories = self._load_memories(paths) + if not memories: + self._write_zero_signals(paths, coin) + return None + + # Fetch current price + symbol = MarketDataClient.coin_to_kucoin_symbol(coin) + current_price = self._market.get_current_price(symbol) + if current_price <= 0: + logger.debug("No price data for %s", coin) + return None + + # Write current price file + self._store.write_signal(paths.current_price(), current_price) + + # Fetch latest candle for pattern matching (use 1hour) + candles = self._market.get_klines(symbol, "1hour", limit=2) + if not candles: + logger.debug("No candle data for %s", coin) + return None + + latest = candles[-1] + signal = generate_signal( + coin=coin, + current_price=current_price, + candle_open=latest.open, + candle_close=latest.close, + memories=memories, + ) + + logger.debug( + "Signal %s: LONG=%d SHORT=%d PM_L=%.2f PM_S=%.2f", + coin, + signal.long_level, + signal.short_level, + signal.long_profit_margin, + signal.short_profit_margin, + ) + return signal + + # -- memory loading ------------------------------------------------------- + + def _load_memories(self, paths: CoinPaths) -> dict[str, PatternMemory]: + """Load pattern memories for all timeframes.""" + memories: dict[str, PatternMemory] = {} + + for tf in TIMEFRAMES: + mem_path = paths.memory_file(tf) + if not mem_path.exists(): + continue + + mem_text = self._store.read_text(mem_path) + if not mem_text.strip(): + continue + + weights_text = self._store.read_text(paths.weight_file(tf)) + weights_high_text = self._store.read_text(paths.weight_high_file(tf)) + weights_low_text = self._store.read_text(paths.weight_low_file(tf)) + threshold = self._store.read_signal(paths.threshold_file(tf), default=1.0) + + memory = PatternMemory.from_memory_text( + mem_text, + weights_text=weights_text, + weights_high_text=weights_high_text, + weights_low_text=weights_low_text, + threshold=threshold, + ) + if not memory.is_empty: + memories[tf] = memory + + return memories + + # -- training freshness gate ---------------------------------------------- + + def _is_trained(self, paths: CoinPaths) -> bool: + """Check if training data is fresh enough to generate signals. + + Returns ``False`` if the training time file is missing or stale. + """ + time_path = paths.base / _TRAINING_TIME_FILENAME + if not time_path.exists(): + # If no training time file, check if any memory files exist + return any(paths.memory_file(tf).exists() for tf in TIMEFRAMES) + + raw = self._store.read_text(time_path).strip() + try: + last_train = float(raw) + except ValueError: + return False + + age = time.time() - last_train + return age < TRAINING_STALE_SECONDS + + # -- signal file writing -------------------------------------------------- + + def _write_signal_files(self, paths: CoinPaths, signal: Signal) -> None: + """Write signal files for the trader to consume.""" + self._store.write_int_signal(paths.signal_long(), signal.long_level) + self._store.write_int_signal(paths.signal_short(), signal.short_level) + self._store.write_signal(paths.profit_margin_long(), signal.long_profit_margin) + self._store.write_signal(paths.profit_margin_short(), signal.short_profit_margin) + + # Write bound prices (HTML format for hub display) + if signal.long_bounds: + self._store.write_text( + paths.bounds_low(), + " ".join(f"{b:.8f}" for b in signal.long_bounds), + ) + if signal.short_bounds: + self._store.write_text( + paths.bounds_high(), + " ".join(f"{b:.8f}" for b in signal.short_bounds), + ) + + def _write_zero_signals(self, paths: CoinPaths, coin: str) -> None: + """Write zero signals for an untrained or unavailable coin.""" + self._store.write_int_signal(paths.signal_long(), 0) + self._store.write_int_signal(paths.signal_short(), 0) + + # -- config hot-reload ---------------------------------------------------- + + def _sync_coins_from_settings(self) -> None: + """Hot-reload coin list from ``gui_settings.json`` if changed.""" + settings_path = self._base_dir / SETTINGS_FILENAME + if not settings_path.exists(): + return + + try: + mtime = settings_path.stat().st_mtime + except OSError: + return + + if mtime <= self._settings_mtime: + return # No change + + self._settings_mtime = mtime + new_config = TradingConfig.from_file(settings_path) + new_coins = list(new_config.coins) + + if new_coins == self._coins: + return # Same coin list + + added = [c for c in new_coins if c not in self._coins] + removed = [c for c in self._coins if c not in new_coins] + + if added: + logger.info("Coins added: %s", added) + if removed: + logger.info("Coins removed: %s", removed) + + self._coins = new_coins + self._coin_paths = build_coin_paths(self._base_dir, self._coins, create_missing=True) + self._config = new_config diff --git a/src/powertrader/trader/runner.py b/src/powertrader/trader/runner.py new file mode 100644 index 000000000..0d051dd35 --- /dev/null +++ b/src/powertrader/trader/runner.py @@ -0,0 +1,436 @@ +"""Trader runner — continuous trade execution loop. + +Replaces the main loop from ``pt_trader.py``. Reads signals from the +thinker, manages open positions, handles entries, DCA buys, and trailing +profit-margin exits. +""" + +from __future__ import annotations + +import logging +import time +from pathlib import Path + +from powertrader.core.config import TradingConfig +from powertrader.core.constants import QUOTE_ASSET +from powertrader.core.paths import CoinPaths, build_coin_paths +from powertrader.core.storage import FileStore +from powertrader.core.trading_client import TradingClient +from powertrader.models.position import Position +from powertrader.models.signal import Signal +from powertrader.models.trade import Trade +from powertrader.trader.dca_engine import DCAEngine +from powertrader.trader.entry_engine import EntryEngine +from powertrader.trader.trailing_engine import TrailingProfitEngine + +logger = logging.getLogger(__name__) + +_LOOP_SLEEP_SECONDS = 0.5 +_POST_TRADE_SLEEP_SECONDS = 5.0 +_STATUS_FILENAME = "trader_status.json" +_TRADE_HISTORY_FILENAME = "trade_history.jsonl" +_ACCOUNT_VALUE_FILENAME = "account_value_history.jsonl" +_HUB_DATA_DIR = "hub_data" + + +class TraderRunner: + """Continuous trade execution loop. + + Parameters + ---------- + trading_client: + Exchange client for placing orders. + entry: + Entry decision engine. + dca: + DCA decision engine. + trailing: + Trailing profit-margin exit engine. + config: + Trading configuration snapshot. + store: + File I/O abstraction. + base_dir: + Root project directory (where coin folders and hub_data live). + """ + + def __init__( + self, + trading_client: TradingClient, + entry: EntryEngine, + dca: DCAEngine, + trailing: TrailingProfitEngine, + config: TradingConfig, + store: FileStore, + base_dir: Path, + ) -> None: + self._client = trading_client + self._entry = entry + self._dca = dca + self._trailing = trailing + self._config = config + self._store = store + self._base_dir = base_dir + self._hub_dir = base_dir / _HUB_DATA_DIR + self._coin_paths = build_coin_paths(base_dir, config.coins) + self._positions: dict[str, Position] = {} + self._running = True + + # -- public API ----------------------------------------------------------- + + def run(self) -> None: + """Main loop: manage positions, check entries, execute trades. + + Runs indefinitely until :meth:`stop` is called. + """ + logger.info("Trader started for %d coins", len(self._config.coins)) + + while self._running: + try: + self.step() + except Exception as exc: + logger.error("Trade management error: %s", exc, exc_info=True) + time.sleep(_LOOP_SLEEP_SECONDS) + + logger.info("Trader stopped") + + def step(self) -> None: + """One iteration: evaluate all positions and potential entries.""" + # Fetch current prices for all coins + prices = self._client.get_current_prices(list(self._coin_paths.keys())) + if not prices: + logger.debug("No prices available, skipping iteration") + return + + # Sync positions from exchange holdings + self._sync_positions(prices) + + # Calculate total account value + account_value = self._calculate_account_value(prices) + + # Manage existing positions (exits and DCA) + for coin in list(self._positions.keys()): + price = prices.get(coin) + if price is None or price <= 0: + continue + self._manage_position(coin, price) + + # Check for new entries + held_coins = set(self._positions.keys()) + for coin, paths in self._coin_paths.items(): + if coin in held_coins: + continue + price = prices.get(coin) + if price is None or price <= 0: + continue + self._check_entry(coin, paths, price, account_value) + + # Write status for hub GUI + self._write_status(prices, account_value) + + def stop(self) -> None: + """Request the runner to stop after the current iteration.""" + self._running = False + + # -- position sync -------------------------------------------------------- + + def _sync_positions(self, prices: dict[str, float]) -> None: + """Sync internal position state with exchange holdings. + + Detects new holdings (from manual trades) and removed holdings + (from external sells). + """ + try: + holdings = self._client.get_holdings() + except Exception as exc: + logger.error("Failed to fetch holdings: %s", exc) + return + + # Add newly detected positions + for coin, qty in holdings.items(): + if coin not in self._coin_paths: + continue # Not a tracked coin + if coin not in self._positions: + price = prices.get(coin, 0.0) + if price > 0 and qty > 0: + self._positions[coin] = Position( + coin=coin, + entry_price=price, + quantity=qty, + cost_basis_usd=qty * price, + ) + logger.info( + "Detected existing position: %s qty=%.8f price=%.4f", + coin, + qty, + price, + ) + + # Remove positions that are no longer held + for coin in list(self._positions.keys()): + if coin not in holdings or holdings[coin] <= 0: + logger.info("Position closed externally: %s", coin) + self._trailing.reset(coin) + del self._positions[coin] + + # -- position management -------------------------------------------------- + + def _manage_position(self, coin: str, current_price: float) -> None: + """Manage an existing position: check exit and DCA.""" + position = self._positions.get(coin) + if position is None: + return + + paths = self._coin_paths.get(coin) + if paths is None: + return + + # Read signals from thinker + signal = self._read_signals(coin, paths) + + # Check trailing exit BEFORE updating state — should_exit uses + # was_above from the *previous* tick's update_trailing call. + if self._trailing.should_exit(position, current_price): + self._execute_exit(coin, position, current_price) + return + + # Update trailing state (sets was_above for the *next* tick) + self._trailing.update_trailing(position, current_price) + + # Check DCA + should_buy, reason = self._dca.should_dca( + position, current_price, long_signal=signal.long_level + ) + if should_buy: + amount = self._dca.calculate_dca_amount(position, current_price) + self._execute_dca(coin, position, current_price, amount, reason) + + def _check_entry( + self, + coin: str, + paths: CoinPaths, + current_price: float, + account_value: float, + ) -> None: + """Check if we should enter a new position for this coin.""" + signal = self._read_signals(coin, paths) + + if not self._entry.should_enter(signal): + return + + entry_size = self._entry.calculate_entry_size(account_value) + if entry_size <= 0: + return + + logger.info( + "Entry signal for %s: LONG=%d SHORT=%d, size=$%.2f", + coin, + signal.long_level, + signal.short_level, + entry_size, + ) + + trade = self._client.market_buy(coin, entry_size) + if trade is None: + logger.error("Entry buy failed for %s", coin) + return + + # Create new position + self._positions[coin] = Position( + coin=coin, + entry_price=trade.price, + quantity=trade.quantity, + cost_basis_usd=trade.value, + ) + + self._record_trade(trade) + logger.info( + "Entered %s: qty=%.8f @ %.4f ($%.2f)", + coin, + trade.quantity, + trade.price, + trade.value, + ) + time.sleep(_POST_TRADE_SLEEP_SECONDS) + + # -- trade execution ------------------------------------------------------ + + def _execute_exit(self, coin: str, position: Position, current_price: float) -> None: + """Execute a trailing profit-margin exit.""" + pnl_pct = position.pnl_pct(current_price) + logger.info("Trailing exit for %s at %.4f (PnL=%.2f%%)", coin, current_price, pnl_pct) + + trade = self._client.market_sell(coin, position.quantity) + if trade is None: + logger.error("Exit sell failed for %s", coin) + return + + # Record with PnL + exit_trade = Trade( + coin=trade.coin, + side=trade.side, + price=trade.price, + quantity=trade.quantity, + value=trade.value, + reason="trailing_exit", + timestamp=trade.timestamp, + pnl_pct=pnl_pct, + order_id=trade.order_id, + ) + self._record_trade(exit_trade) + + # Clean up state + self._trailing.reset(coin) + self._dca.record_sell(coin) + del self._positions[coin] + + logger.info( + "Exited %s: qty=%.8f @ %.4f ($%.2f, PnL=%.2f%%)", + coin, + trade.quantity, + trade.price, + trade.value, + pnl_pct, + ) + time.sleep(_POST_TRADE_SLEEP_SECONDS) + + def _execute_dca( + self, + coin: str, + position: Position, + current_price: float, + amount: float, + reason: str, + ) -> None: + """Execute a DCA buy.""" + logger.info( + "DCA buy for %s: reason=%s, amount=$%.2f at %.4f", + coin, + reason, + amount, + current_price, + ) + + trade = self._client.market_buy(coin, amount) + if trade is None: + logger.error("DCA buy failed for %s (reason=%s)", coin, reason) + return + + # Update position + position.quantity += trade.quantity + position.cost_basis_usd += trade.value + position.dca_count += 1 + position.dca_timestamps.append(trade.timestamp) + + # Record DCA in rate limiter + self._dca.record_dca_buy(coin, trade.timestamp) + + # Reset trailing state after DCA (PM line changes) + self._trailing.reset(coin) + + # Record trade + dca_trade = Trade( + coin=trade.coin, + side=trade.side, + price=trade.price, + quantity=trade.quantity, + value=trade.value, + reason=reason, + timestamp=trade.timestamp, + order_id=trade.order_id, + ) + self._record_trade(dca_trade) + + logger.info( + "DCA %s: qty=%.8f @ %.4f ($%.2f), total_qty=%.8f, avg=%.4f", + coin, + trade.quantity, + trade.price, + trade.value, + position.quantity, + position.avg_price, + ) + time.sleep(_POST_TRADE_SLEEP_SECONDS) + + # -- signal reading ------------------------------------------------------- + + def _read_signals(self, coin: str, paths: CoinPaths) -> Signal: + """Read signal files written by the thinker.""" + long_level = self._store.read_int_signal(paths.signal_long(), default=0) + short_level = self._store.read_int_signal(paths.signal_short(), default=0) + long_pm = self._store.read_signal(paths.profit_margin_long(), default=0.0) + short_pm = self._store.read_signal(paths.profit_margin_short(), default=0.0) + + return Signal( + coin=coin, + long_level=long_level, + short_level=short_level, + long_profit_margin=long_pm, + short_profit_margin=short_pm, + timestamp=time.time(), + ) + + # -- account value -------------------------------------------------------- + + def _calculate_account_value(self, prices: dict[str, float]) -> float: + """Calculate total account value (USDT + holdings).""" + try: + balances = self._client.get_account_balance() + except Exception as exc: + logger.error("Failed to fetch account balance: %s", exc) + return 0.0 + + total = balances.get(QUOTE_ASSET, 0.0) + for coin, qty in balances.items(): + if coin == QUOTE_ASSET: + continue + price = prices.get(coin, 0.0) + total += qty * price + + return total + + # -- trade recording ------------------------------------------------------ + + def _record_trade(self, trade: Trade) -> None: + """Record a trade to the JSONL history file.""" + self._hub_dir.mkdir(parents=True, exist_ok=True) + self._store.append_jsonl( + self._hub_dir / _TRADE_HISTORY_FILENAME, + trade.to_dict(), + ) + + # -- status writing ------------------------------------------------------- + + def _write_status(self, prices: dict[str, float], account_value: float) -> None: + """Write trader status for the hub GUI.""" + self._hub_dir.mkdir(parents=True, exist_ok=True) + + positions_data: dict[str, object] = {} + for coin, pos in self._positions.items(): + price = prices.get(coin, 0.0) + trail_info = self._trailing.get_display_info(pos, price) + positions_data[coin] = { + "quantity": pos.quantity, + "avg_price": pos.avg_price, + "entry_price": pos.entry_price, + "current_price": price, + "pnl_pct": pos.pnl_pct(price), + "market_value": pos.market_value(price), + "dca_count": pos.dca_count, + **trail_info, + } + + status = { + "account_value": account_value, + "positions": positions_data, + "coins": list(self._coin_paths.keys()), + "timestamp": time.time(), + } + + self._store.write_json(self._hub_dir / _STATUS_FILENAME, status) + + # Append account value snapshot + self._store.append_jsonl( + self._hub_dir / _ACCOUNT_VALUE_FILENAME, + {"value": account_value, "timestamp": time.time()}, + ) diff --git a/src/powertrader/trainer/runner.py b/src/powertrader/trainer/runner.py new file mode 100644 index 000000000..413645468 --- /dev/null +++ b/src/powertrader/trainer/runner.py @@ -0,0 +1,314 @@ +"""Trainer runner — orchestrates training across all coins and timeframes. + +Replaces the main loop from ``pt_trainer.py``. For each coin, iterates +through all 7 timeframes, fetches historical candle data, builds pattern +memories, adjusts weights, and persists results to disk. + +Supports graceful stop via ``killer.txt`` and checkpoint-based resume. +""" + +from __future__ import annotations + +import contextlib +import logging +import time +from collections.abc import Callable +from pathlib import Path + +from powertrader.core.config import TradingConfig +from powertrader.core.constants import ( + KILLER_CHECK_INTERVAL, + KILLER_FILENAME, + TIMEFRAMES, + TRAINER_LOOKBACK_CANDLES, +) +from powertrader.core.market_client import MarketDataClient +from powertrader.core.paths import CoinPaths, build_coin_paths +from powertrader.core.storage import FileStore +from powertrader.models.memory import PatternMemory +from powertrader.trainer.training_engine import ( + adjust_weights, + build_patterns, + normalize_candles, +) + +logger = logging.getLogger(__name__) + +_CHECKPOINT_FILENAME = "trainer_checkpoint.json" +_STATUS_FILENAME = "trainer_status.json" + + +class TrainerRunner: + """Orchestrates training across all coins and timeframes. + + Parameters + ---------- + market: + Market data client for fetching historical candles. + config: + Trading configuration snapshot. + store: + File I/O abstraction. + base_dir: + Root project directory (where coin folders live). + on_progress: + Optional callback ``(coin, timeframe, position, total)`` for progress. + """ + + def __init__( + self, + market: MarketDataClient, + config: TradingConfig, + store: FileStore, + base_dir: Path, + on_progress: Callable[[str, str, int, int], None] | None = None, + ) -> None: + self._market = market + self._config = config + self._store = store + self._base_dir = base_dir + self._on_progress = on_progress + self._coin_paths: dict[str, CoinPaths] = {} + + # -- public API ----------------------------------------------------------- + + def run( + self, + coins: list[str] | None = None, + reprocess: bool = False, + ) -> None: + """Train all configured coins sequentially. + + Parameters + ---------- + coins: + Specific coins to train. ``None`` uses ``config.coins``. + reprocess: + If ``True``, rebuild memories from scratch instead of + adjusting existing weights. + """ + coin_list = coins if coins is not None else list(self._config.coins) + self._coin_paths = build_coin_paths(self._base_dir, coin_list, create_missing=True) + + # Load checkpoint for resume + checkpoint = self._load_checkpoint() + start_coin = checkpoint.get("coin", "") + start_tf_idx = checkpoint.get("tf_index", 0) + resumed = bool(start_coin) + + self._write_status("TRAINING", coin="", timeframe="") + logger.info("Training started for %d coins (reprocess=%s)", len(coin_list), reprocess) + + for coin in coin_list: + if coin not in self._coin_paths: + logger.warning("Skipping %s: no coin folder found", coin) + continue + + # Resume logic: skip coins before the checkpoint + if resumed and coin != start_coin: + continue + resumed = False # Found checkpoint coin, start from here + + tf_start = start_tf_idx if coin == start_coin else 0 + start_coin = "" # Only apply resume offset once + + try: + self._train_coin(coin, reprocess=reprocess, tf_start=tf_start) + except _StopTrainingError: + logger.info("Training interrupted by stop signal at coin=%s", coin) + self._write_status("INTERRUPTED", coin=coin, timeframe="") + return + + # Training complete for all coins + self._clear_checkpoint() + self._write_status("FINISHED", coin="", timeframe="") + logger.info("Training complete for all %d coins", len(coin_list)) + + def should_stop(self) -> bool: + """Check ``killer.txt`` for a stop signal.""" + killer_path = self._base_dir / KILLER_FILENAME + content = self._store.read_text(killer_path).strip().lower() + return content == "yes" + + # -- per-coin training ---------------------------------------------------- + + def _train_coin( + self, + coin: str, + reprocess: bool = False, + tf_start: int = 0, + ) -> None: + """Train one coin across all timeframes.""" + paths = self._coin_paths[coin] + logger.info("Training %s (reprocess=%s, tf_start=%d)", coin, reprocess, tf_start) + + for tf_idx, timeframe in enumerate(TIMEFRAMES): + if tf_idx < tf_start: + continue + + self._write_status("TRAINING", coin=coin, timeframe=timeframe) + self._save_checkpoint(coin, tf_idx) + + # Check stop signal between timeframes + if self.should_stop(): + raise _StopTrainingError() + + try: + self._train_timeframe(coin, paths, timeframe, reprocess=reprocess) + except _StopTrainingError: + raise + except Exception as exc: + logger.error("Training %s/%s failed: %s", coin, timeframe, exc, exc_info=True) + # Continue with next timeframe on non-fatal errors + + logger.info("Training complete for %s", coin) + + def _train_timeframe( + self, + coin: str, + paths: CoinPaths, + timeframe: str, + reprocess: bool = False, + ) -> None: + """Train one coin on one timeframe.""" + logger.info("Training %s/%s — fetching history", coin, timeframe) + + # Fetch historical candle data + symbol = MarketDataClient.coin_to_kucoin_symbol(coin) + candles = self._market.get_all_klines( + symbol, timeframe, max_candles=TRAINER_LOOKBACK_CANDLES + ) + if not candles: + logger.warning("No candle data for %s/%s", coin, timeframe) + return + + logger.info("Fetched %d candles for %s/%s", len(candles), coin, timeframe) + + # Normalize candle data + close_pcts, high_pcts, low_pcts = normalize_candles(candles) + + if reprocess or not self._memory_exists(paths, timeframe): + # Build fresh memory from historical data + memory = build_patterns(close_pcts, high_pcts, low_pcts) + logger.info("Built %d patterns for %s/%s", memory.size, coin, timeframe) + else: + # Load existing memory and adjust weights + memory = self._load_memory(paths, timeframe) + if memory.is_empty: + memory = build_patterns(close_pcts, high_pcts, low_pcts) + logger.info( + "Rebuilt %d patterns for %s/%s (was empty)", + memory.size, + coin, + timeframe, + ) + + # Adjust weights with progress callback that checks stop signal + iteration_count = 0 + + def _progress(pos: int, total: int) -> None: + nonlocal iteration_count + iteration_count += 1 + + if self._on_progress: + self._on_progress(coin, timeframe, pos, total) + + # Periodically check stop signal during weight adjustment + if iteration_count % KILLER_CHECK_INTERVAL == 0 and self.should_stop(): + # Save progress before stopping + self._save_memory(paths, timeframe, memory) + raise _StopTrainingError() + + memory = adjust_weights(memory, close_pcts, high_pcts, low_pcts, on_progress=_progress) + + # Persist to disk + self._save_memory(paths, timeframe, memory) + logger.info( + "Saved memory for %s/%s (%d patterns, threshold=%.4f)", + coin, + timeframe, + memory.size, + memory.threshold, + ) + + # -- memory I/O ----------------------------------------------------------- + + def _memory_exists(self, paths: CoinPaths, timeframe: str) -> bool: + """Check if memory files exist for a timeframe.""" + return paths.memory_file(timeframe).exists() + + def _load_memory(self, paths: CoinPaths, timeframe: str) -> PatternMemory: + """Load pattern memory from disk.""" + mem_text = self._store.read_text(paths.memory_file(timeframe)) + weights_text = self._store.read_text(paths.weight_file(timeframe)) + weights_high_text = self._store.read_text(paths.weight_high_file(timeframe)) + weights_low_text = self._store.read_text(paths.weight_low_file(timeframe)) + threshold = self._store.read_signal(paths.threshold_file(timeframe), default=1.0) + + return PatternMemory.from_memory_text( + mem_text, + weights_text=weights_text, + weights_high_text=weights_high_text, + weights_low_text=weights_low_text, + threshold=threshold, + ) + + def _save_memory(self, paths: CoinPaths, timeframe: str, memory: PatternMemory) -> None: + """Persist pattern memory and weights to disk.""" + paths.ensure_dir() + + # Memory patterns (with high/low diffs embedded) + self._store.write_text(paths.memory_file(timeframe), memory.to_memory_text()) + + # Separate weight files (space-separated floats) + self._store.write_text( + paths.weight_file(timeframe), + " ".join(str(w) for w in memory.weights), + ) + self._store.write_text( + paths.weight_high_file(timeframe), + " ".join(str(w) for w in memory.weights_high), + ) + self._store.write_text( + paths.weight_low_file(timeframe), + " ".join(str(w) for w in memory.weights_low), + ) + + # Threshold + self._store.write_signal(paths.threshold_file(timeframe), memory.threshold) + + # -- checkpoint ----------------------------------------------------------- + + def _save_checkpoint(self, coin: str, tf_index: int) -> None: + """Save training progress for resume.""" + data = {"coin": coin, "tf_index": tf_index, "timestamp": time.time()} + self._store.write_json(self._base_dir / _CHECKPOINT_FILENAME, data) + + def _load_checkpoint(self) -> dict[str, object]: + """Load saved checkpoint, or return empty dict.""" + data = self._store.read_json(self._base_dir / _CHECKPOINT_FILENAME, default={}) + return data if isinstance(data, dict) else {} + + def _clear_checkpoint(self) -> None: + """Remove the checkpoint file after successful completion.""" + path = self._base_dir / _CHECKPOINT_FILENAME + with contextlib.suppress(OSError): + path.unlink(missing_ok=True) + + # -- status --------------------------------------------------------------- + + def _write_status(self, state: str, coin: str, timeframe: str) -> None: + """Write ``trainer_status.json`` for the hub GUI.""" + self._store.write_json( + self._base_dir / _STATUS_FILENAME, + { + "state": state, + "coin": coin, + "timeframe": timeframe, + "timestamp": time.time(), + }, + ) + + +class _StopTrainingError(Exception): + """Internal signal to unwind the training stack on graceful stop.""" diff --git a/tests/integration/test_thinker_runner.py b/tests/integration/test_thinker_runner.py new file mode 100644 index 000000000..3937b3314 --- /dev/null +++ b/tests/integration/test_thinker_runner.py @@ -0,0 +1,333 @@ +"""Integration tests for ThinkerRunner. + +Uses a mock market client and pre-built memory files to verify the +full signal generation pipeline without hitting real exchange APIs. +""" + +from __future__ import annotations + +import time +from pathlib import Path + +import pytest + +from powertrader.core.config import TradingConfig +from powertrader.core.constants import TIMEFRAMES +from powertrader.core.market_client import MarketDataClient +from powertrader.core.paths import CoinPaths +from powertrader.core.storage import FileStore +from powertrader.models.candle import Candle +from powertrader.models.memory import PatternMemory +from powertrader.thinker.runner import ThinkerRunner + +# --------------------------------------------------------------------------- +# Mock market client +# --------------------------------------------------------------------------- + + +class MockMarketClient(MarketDataClient): + """Returns deterministic data for testing.""" + + def __init__(self, price: float = 50000.0) -> None: + self._price = price + + def get_klines( + self, + symbol: str, + timeframe: str, + limit: int = 1500, + start_at: int | None = None, + end_at: int | None = None, + ) -> list[Candle]: + p = self._price + return [ + Candle( + timestamp=int(time.time()) - 7200, + open=p * 0.99, + close=p * 0.995, + high=p * 1.01, + low=p * 0.98, + volume=100.0, + ), + Candle( + timestamp=int(time.time()) - 3600, + open=p * 0.995, + close=p, + high=p * 1.005, + low=p * 0.99, + volume=150.0, + ), + ] + + def get_current_price(self, symbol: str) -> float: + return self._price + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _write_simple_memory(store: FileStore, paths: CoinPaths) -> None: + """Write a simple memory with known patterns to a coin folder.""" + # Create a memory with a few patterns + patterns = [[0.5], [1.0], [-0.5], [0.2]] + high_diffs = [0.02, 0.03, 0.01, 0.015] + low_diffs = [-0.01, -0.02, -0.005, -0.01] + weights = [1.0, 1.0, 1.0, 1.0] + + memory = PatternMemory( + patterns=patterns, + high_diffs=high_diffs, + low_diffs=low_diffs, + weights=weights, + weights_high=list(weights), + weights_low=list(weights), + threshold=50.0, # Very permissive to ensure matches + ) + + paths.ensure_dir() + for tf in TIMEFRAMES: + store.write_text(paths.memory_file(tf), memory.to_memory_text()) + store.write_text(paths.weight_file(tf), " ".join(str(w) for w in memory.weights)) + store.write_text( + paths.weight_high_file(tf), + " ".join(str(w) for w in memory.weights_high), + ) + store.write_text( + paths.weight_low_file(tf), + " ".join(str(w) for w in memory.weights_low), + ) + store.write_signal(paths.threshold_file(tf), memory.threshold) + + +@pytest.fixture +def base_dir(tmp_path: Path) -> Path: + return tmp_path + + +@pytest.fixture +def config() -> TradingConfig: + return TradingConfig(coins=["BTC", "ETH"]) + + +@pytest.fixture +def store() -> FileStore: + return FileStore() + + +@pytest.fixture +def market() -> MockMarketClient: + return MockMarketClient(price=50000.0) + + +@pytest.fixture +def runner_with_memories( + market: MockMarketClient, + config: TradingConfig, + store: FileStore, + base_dir: Path, +) -> ThinkerRunner: + """ThinkerRunner with pre-built memory files for BTC and ETH.""" + # Create BTC memories (in root) + btc_paths = CoinPaths(base_dir, "BTC") + _write_simple_memory(store, btc_paths) + + # Create ETH memories (in subfolder) + eth_paths = CoinPaths(base_dir, "ETH") + _write_simple_memory(store, eth_paths) + + return ThinkerRunner(market=market, config=config, store=store, base_dir=base_dir) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestThinkerRunnerStep: + """Test single-step signal generation.""" + + def test_generates_signals_for_all_coins(self, runner_with_memories: ThinkerRunner) -> None: + """A single step should generate signals for all configured coins.""" + signals = runner_with_memories.step() + assert "BTC" in signals + assert "ETH" in signals + + def test_signal_levels_are_valid(self, runner_with_memories: ThinkerRunner) -> None: + """Signal levels should be in the 0-7 range.""" + signals = runner_with_memories.step() + for coin, signal in signals.items(): + assert 0 <= signal.long_level <= 7, f"{coin} long_level={signal.long_level}" + assert 0 <= signal.short_level <= 7, f"{coin} short_level={signal.short_level}" + + def test_writes_signal_files( + self, + runner_with_memories: ThinkerRunner, + base_dir: Path, + store: FileStore, + ) -> None: + """Step should write signal files for the trader.""" + runner_with_memories.step() + + # BTC signal files in root + assert (base_dir / "long_dca_signal.txt").exists() + assert (base_dir / "short_dca_signal.txt").exists() + + # Values should be integers + long_val = store.read_int_signal(base_dir / "long_dca_signal.txt") + assert 0 <= long_val <= 7 + + def test_writes_profit_margin_files( + self, + runner_with_memories: ThinkerRunner, + base_dir: Path, + store: FileStore, + ) -> None: + """Step should write profit margin files.""" + runner_with_memories.step() + + assert (base_dir / "futures_long_profit_margin.txt").exists() + assert (base_dir / "futures_short_profit_margin.txt").exists() + + def test_writes_current_price( + self, + runner_with_memories: ThinkerRunner, + base_dir: Path, + store: FileStore, + ) -> None: + """Step should write the current price file.""" + runner_with_memories.step() + + price = store.read_signal(base_dir / "BTC_current_price.txt") + assert price == 50000.0 + + +class TestThinkerRunnerNoMemory: + """Test behavior without trained memories.""" + + def test_untrained_coin_gets_zero_signals( + self, + market: MockMarketClient, + store: FileStore, + base_dir: Path, + ) -> None: + """Coins without memory files should get zero signals.""" + config = TradingConfig(coins=["BTC"]) + runner = ThinkerRunner(market=market, config=config, store=store, base_dir=base_dir) + signals = runner.step() + + # BTC has no memory files, so should return None (not in dict) + assert "BTC" not in signals + + # Zero signals should be written + val = store.read_int_signal(base_dir / "long_dca_signal.txt") + assert val == 0 + + +class TestThinkerRunnerHotReload: + """Test config hot-reload.""" + + def test_detects_added_coin( + self, + runner_with_memories: ThinkerRunner, + base_dir: Path, + store: FileStore, + ) -> None: + """Should detect a new coin added to gui_settings.json.""" + # Initial step + runner_with_memories.step() + + # Write new settings with an extra coin + new_settings = { + "coins": ["BTC", "ETH", "XRP"], + } + store.write_json(base_dir / "gui_settings.json", new_settings) + + # Force mtime detection + runner_with_memories._settings_mtime = 0 + + # Trigger hot-reload + runner_with_memories._sync_coins_from_settings() + + # XRP should now be in the coin list + assert "XRP" in runner_with_memories._coins + + def test_detects_removed_coin( + self, + runner_with_memories: ThinkerRunner, + base_dir: Path, + store: FileStore, + ) -> None: + """Should detect a coin removed from gui_settings.json.""" + # Write settings with fewer coins + new_settings = {"coins": ["BTC"]} + store.write_json(base_dir / "gui_settings.json", new_settings) + runner_with_memories._settings_mtime = 0 + + runner_with_memories._sync_coins_from_settings() + + assert "ETH" not in runner_with_memories._coins + assert "BTC" in runner_with_memories._coins + + +class TestThinkerRunnerStop: + """Test stop mechanism.""" + + def test_stop_flag(self, runner_with_memories: ThinkerRunner) -> None: + """Setting stop should break the main loop.""" + runner_with_memories.stop() + assert runner_with_memories._running is False + + +class TestThinkerRunnerEdgeCases: + """Test edge cases.""" + + def test_zero_price_skips_coin( + self, + config: TradingConfig, + store: FileStore, + base_dir: Path, + ) -> None: + """Coins with zero price should be skipped.""" + + class ZeroPriceMarket(MockMarketClient): + def get_current_price(self, symbol: str) -> float: + return 0.0 + + btc_paths = CoinPaths(base_dir, "BTC") + _write_simple_memory(store, btc_paths) + + runner = ThinkerRunner( + market=ZeroPriceMarket(), + config=TradingConfig(coins=["BTC"]), + store=store, + base_dir=base_dir, + ) + signals = runner.step() + assert "BTC" not in signals + + def test_api_error_handled_gracefully( + self, + config: TradingConfig, + store: FileStore, + base_dir: Path, + ) -> None: + """API errors should not crash the runner.""" + + class ErrorMarket(MockMarketClient): + def get_current_price(self, symbol: str) -> float: + raise ConnectionError("Network error") + + btc_paths = CoinPaths(base_dir, "BTC") + _write_simple_memory(store, btc_paths) + + runner = ThinkerRunner( + market=ErrorMarket(), + config=TradingConfig(coins=["BTC"]), + store=store, + base_dir=base_dir, + ) + # Should not raise + signals = runner.step() + assert "BTC" not in signals diff --git a/tests/integration/test_trader_runner.py b/tests/integration/test_trader_runner.py new file mode 100644 index 000000000..7ee276d7e --- /dev/null +++ b/tests/integration/test_trader_runner.py @@ -0,0 +1,492 @@ +"""Integration tests for TraderRunner. + +Uses a mock trading client to verify the full trade management pipeline +without placing real orders. +""" + +from __future__ import annotations + +import time +from pathlib import Path + +import pytest + +from powertrader.core.config import TradingConfig +from powertrader.core.paths import CoinPaths +from powertrader.core.storage import FileStore +from powertrader.core.trading_client import TradingClient +from powertrader.models.position import Position +from powertrader.models.trade import Trade +from powertrader.trader.dca_engine import DCAEngine +from powertrader.trader.entry_engine import EntryEngine +from powertrader.trader.runner import TraderRunner +from powertrader.trader.trailing_engine import TrailingProfitEngine + +# --------------------------------------------------------------------------- +# Mock trading client +# --------------------------------------------------------------------------- + + +class MockTradingClient(TradingClient): + """Records all operations for assertion.""" + + def __init__( + self, + balance: float = 10000.0, + prices: dict[str, float] | None = None, + holdings: dict[str, float] | None = None, + ) -> None: + self._balance = balance + self._prices = prices or {} + self._holdings = dict(holdings or {}) + self.buy_calls: list[tuple[str, float]] = [] + self.sell_calls: list[tuple[str, float]] = [] + + def get_account_balance(self) -> dict[str, float]: + result: dict[str, float] = {"USDT": self._balance} + for coin, qty in self._holdings.items(): + result[coin] = qty + return result + + def get_holdings(self) -> dict[str, float]: + return {c: q for c, q in self._holdings.items() if q > 0} + + def market_buy(self, coin: str, quote_amount: float) -> Trade | None: + self.buy_calls.append((coin, quote_amount)) + price = self._prices.get(coin, 50000.0) + qty = quote_amount / price + self._holdings[coin] = self._holdings.get(coin, 0.0) + qty + self._balance -= quote_amount + return Trade( + coin=coin, + side="BUY", + price=price, + quantity=qty, + value=quote_amount, + reason="entry", + timestamp=time.time(), + ) + + def market_sell(self, coin: str, quantity: float) -> Trade | None: + self.sell_calls.append((coin, quantity)) + price = self._prices.get(coin, 50000.0) + value = quantity * price + self._holdings[coin] = max(0.0, self._holdings.get(coin, 0.0) - quantity) + self._balance += value + return Trade( + coin=coin, + side="SELL", + price=price, + quantity=quantity, + value=value, + reason="exit", + timestamp=time.time(), + ) + + def get_current_prices(self, coins: list[str]) -> dict[str, float]: + return {c: self._prices[c] for c in coins if c in self._prices} + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def base_dir(tmp_path: Path) -> Path: + (tmp_path / "ETH").mkdir() + (tmp_path / "hub_data").mkdir() + return tmp_path + + +@pytest.fixture +def config() -> TradingConfig: + return TradingConfig(coins=["BTC", "ETH"]) + + +@pytest.fixture +def store() -> FileStore: + return FileStore() + + +def _write_signals( + store: FileStore, + paths: CoinPaths, + long_level: int = 0, + short_level: int = 0, + long_pm: float = 0.25, + short_pm: float = 0.25, +) -> None: + """Write signal files for a coin.""" + paths.ensure_dir() + store.write_int_signal(paths.signal_long(), long_level) + store.write_int_signal(paths.signal_short(), short_level) + store.write_signal(paths.profit_margin_long(), long_pm) + store.write_signal(paths.profit_margin_short(), short_pm) + + +def _make_runner( + client: MockTradingClient, + config: TradingConfig, + store: FileStore, + base_dir: Path, +) -> TraderRunner: + """Create a TraderRunner with all engines wired up.""" + entry = EntryEngine(config) + dca = DCAEngine(config) + trailing = TrailingProfitEngine(config) + return TraderRunner( + trading_client=client, + entry=entry, + dca=dca, + trailing=trailing, + config=config, + store=store, + base_dir=base_dir, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestTraderRunnerEntry: + """Test trade entry logic.""" + + def test_enters_on_strong_long_signal( + self, config: TradingConfig, store: FileStore, base_dir: Path + ) -> None: + """Should enter when LONG >= 3 and SHORT == 0.""" + client = MockTradingClient(balance=10000.0, prices={"BTC": 50000.0, "ETH": 3000.0}) + runner = _make_runner(client, config, store, base_dir) + + # Write strong LONG signal for BTC + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=5, short_level=0) + + # Neutral for ETH + eth_paths = CoinPaths(base_dir, "ETH") + _write_signals(store, eth_paths, long_level=0, short_level=0) + + runner.step() + + # Should have placed a buy for BTC + assert len(client.buy_calls) == 1 + assert client.buy_calls[0][0] == "BTC" + + def test_no_entry_on_weak_signal( + self, config: TradingConfig, store: FileStore, base_dir: Path + ) -> None: + """Should NOT enter when LONG < trade_start_level.""" + client = MockTradingClient(balance=10000.0, prices={"BTC": 50000.0, "ETH": 3000.0}) + runner = _make_runner(client, config, store, base_dir) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=2, short_level=0) + + eth_paths = CoinPaths(base_dir, "ETH") + _write_signals(store, eth_paths, long_level=1, short_level=0) + + runner.step() + + assert len(client.buy_calls) == 0 + + def test_no_entry_with_short_signal( + self, config: TradingConfig, store: FileStore, base_dir: Path + ) -> None: + """Should NOT enter when SHORT > 0 even if LONG is high.""" + client = MockTradingClient(balance=10000.0, prices={"BTC": 50000.0, "ETH": 3000.0}) + runner = _make_runner(client, config, store, base_dir) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=5, short_level=1) + + runner.step() + + assert len(client.buy_calls) == 0 + + def test_entry_size_matches_config(self, store: FileStore, base_dir: Path) -> None: + """Entry size should be account_value * start_allocation_pct.""" + config = TradingConfig(coins=["BTC"], start_allocation_pct=0.01) # 1% of account + client = MockTradingClient(balance=10000.0, prices={"BTC": 50000.0}) + runner = _make_runner(client, config, store, base_dir) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=5, short_level=0) + + runner.step() + + assert len(client.buy_calls) == 1 + # ~$100 = 10000 * 0.01 (exact value depends on account value calculation) + buy_amount = client.buy_calls[0][1] + assert buy_amount > 0 + + +class TestTraderRunnerExit: + """Test trailing profit-margin exit.""" + + def test_exit_on_trailing_crossover(self, store: FileStore, base_dir: Path) -> None: + """Should sell when price crosses below trailing line.""" + config = TradingConfig( + coins=["BTC"], + pm_start_pct_no_dca=5.0, + trailing_gap_pct=0.5, + ) + # Start with BTC holding, price above PM line + client = MockTradingClient( + balance=5000.0, + prices={"BTC": 52500.0}, # 5% above entry + holdings={"BTC": 0.001}, + ) + runner = _make_runner(client, config, store, base_dir) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=3, short_level=0) + + # Manually inject a position with known cost basis + runner._positions["BTC"] = Position( + coin="BTC", + entry_price=50000.0, + quantity=0.001, + cost_basis_usd=50.0, + ) + + # Step 1: Price at 52500 (5% above entry) — activates trailing + runner.step() + assert len(client.sell_calls) == 0 # Not yet — just activated + + # Step 2: Price rises to 53000 — peak tracking + client._prices["BTC"] = 53000.0 + runner.step() + assert len(client.sell_calls) == 0 + + # Step 3: Price drops below trailing line (53000 * 0.995 = 52735) + client._prices["BTC"] = 52700.0 + runner.step() + + # Should have sold + assert len(client.sell_calls) == 1 + assert client.sell_calls[0][0] == "BTC" + + +class TestTraderRunnerDCA: + """Test DCA (dollar cost averaging) logic.""" + + def test_dca_on_hard_threshold(self, store: FileStore, base_dir: Path) -> None: + """Should DCA when PnL drops below hard threshold.""" + config = TradingConfig( + coins=["BTC"], + dca_levels=[-2.5, -5.0, -10.0], + dca_multiplier=2.0, + max_dca_buys_per_24h=2, + ) + # Price dropped 3% from entry + client = MockTradingClient( + balance=5000.0, + prices={"BTC": 48500.0}, + holdings={"BTC": 0.001}, + ) + runner = _make_runner(client, config, store, base_dir) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=3, short_level=0) + + # Position with entry at 50000 + runner._positions["BTC"] = Position( + coin="BTC", + entry_price=50000.0, + quantity=0.001, + cost_basis_usd=50.0, # $50 spent + ) + + runner.step() + + # Should have placed a DCA buy (-3% < -2.5% threshold) + assert len(client.buy_calls) >= 1 + buy_coin, buy_amount = client.buy_calls[0] + assert buy_coin == "BTC" + assert buy_amount > 0 + + +class TestTraderRunnerPositionSync: + """Test position syncing with exchange.""" + + def test_detects_new_holdings( + self, config: TradingConfig, store: FileStore, base_dir: Path + ) -> None: + """Should detect holdings from the exchange and create positions.""" + client = MockTradingClient( + balance=9000.0, + prices={"BTC": 50000.0, "ETH": 3000.0}, + holdings={"BTC": 0.01}, + ) + runner = _make_runner(client, config, store, base_dir) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=0, short_level=0) + eth_paths = CoinPaths(base_dir, "ETH") + _write_signals(store, eth_paths, long_level=0, short_level=0) + + runner.step() + + # Should have detected BTC position + assert "BTC" in runner._positions + assert runner._positions["BTC"].quantity == 0.01 + + def test_removes_closed_positions( + self, config: TradingConfig, store: FileStore, base_dir: Path + ) -> None: + """Should remove positions that are no longer held.""" + client = MockTradingClient( + balance=10000.0, + prices={"BTC": 50000.0, "ETH": 3000.0}, + holdings={}, # No holdings + ) + runner = _make_runner(client, config, store, base_dir) + + # Inject a stale position + runner._positions["BTC"] = Position( + coin="BTC", entry_price=50000.0, quantity=0.001, cost_basis_usd=50.0 + ) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=0, short_level=0) + eth_paths = CoinPaths(base_dir, "ETH") + _write_signals(store, eth_paths, long_level=0, short_level=0) + + runner.step() + + # Position should be removed + assert "BTC" not in runner._positions + + +class TestTraderRunnerStatusOutput: + """Test status file output.""" + + def test_writes_trader_status( + self, config: TradingConfig, store: FileStore, base_dir: Path + ) -> None: + """Should write trader_status.json.""" + client = MockTradingClient(balance=10000.0, prices={"BTC": 50000.0, "ETH": 3000.0}) + runner = _make_runner(client, config, store, base_dir) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=0, short_level=0) + eth_paths = CoinPaths(base_dir, "ETH") + _write_signals(store, eth_paths, long_level=0, short_level=0) + + runner.step() + + status_path = base_dir / "hub_data" / "trader_status.json" + status = store.read_json(status_path) + assert status is not None + assert "account_value" in status + assert "positions" in status + assert "coins" in status + assert status["account_value"] > 0 + + def test_writes_account_value_history( + self, config: TradingConfig, store: FileStore, base_dir: Path + ) -> None: + """Should append to account_value_history.jsonl.""" + client = MockTradingClient(balance=10000.0, prices={"BTC": 50000.0, "ETH": 3000.0}) + runner = _make_runner(client, config, store, base_dir) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=0, short_level=0) + eth_paths = CoinPaths(base_dir, "ETH") + _write_signals(store, eth_paths, long_level=0, short_level=0) + + runner.step() + runner.step() + + history_path = base_dir / "hub_data" / "account_value_history.jsonl" + assert history_path.exists() + lines = history_path.read_text().strip().split("\n") + assert len(lines) >= 2 # At least 2 snapshots + + def test_records_trades(self, store: FileStore, base_dir: Path) -> None: + """Executed trades should be appended to trade_history.jsonl.""" + config = TradingConfig(coins=["BTC"]) + client = MockTradingClient(balance=10000.0, prices={"BTC": 50000.0}) + runner = _make_runner(client, config, store, base_dir) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=5, short_level=0) + + runner.step() + + trade_path = base_dir / "hub_data" / "trade_history.jsonl" + assert trade_path.exists() + lines = trade_path.read_text().strip().split("\n") + assert len(lines) >= 1 + + +class TestTraderRunnerStop: + """Test stop mechanism.""" + + def test_stop_flag(self, config: TradingConfig, store: FileStore, base_dir: Path) -> None: + """Setting stop should break the main loop.""" + client = MockTradingClient(balance=10000.0, prices={}) + runner = _make_runner(client, config, store, base_dir) + runner.stop() + assert runner._running is False + + +class TestTraderRunnerEdgeCases: + """Test edge cases.""" + + def test_no_prices_skips_iteration( + self, config: TradingConfig, store: FileStore, base_dir: Path + ) -> None: + """Should handle missing prices gracefully.""" + client = MockTradingClient(balance=10000.0, prices={}) + runner = _make_runner(client, config, store, base_dir) + + # Should not raise + runner.step() + assert len(client.buy_calls) == 0 + + def test_failed_buy_handled( + self, config: TradingConfig, store: FileStore, base_dir: Path + ) -> None: + """Should handle failed buy orders gracefully.""" + + class FailingClient(MockTradingClient): + def market_buy(self, coin: str, quote_amount: float) -> Trade | None: + return None # Simulates failure + + client = FailingClient(balance=10000.0, prices={"BTC": 50000.0}) + runner = _make_runner(client, config, store, base_dir) + + btc_paths = CoinPaths(base_dir, "BTC") + _write_signals(store, btc_paths, long_level=5, short_level=0) + + # Should not raise + runner.step() + assert "BTC" not in runner._positions + + +class TestFileIPC: + """Test file-based inter-process communication between thinker and trader.""" + + def test_signal_files_roundtrip(self, store: FileStore, base_dir: Path) -> None: + """Signal files written by thinker format should be readable by trader.""" + paths = CoinPaths(base_dir, "BTC") + paths.ensure_dir() + + # Write signals (as thinker would) + store.write_int_signal(paths.signal_long(), 5) + store.write_int_signal(paths.signal_short(), 0) + store.write_signal(paths.profit_margin_long(), 2.5) + store.write_signal(paths.profit_margin_short(), 0.0) + + # Read signals (as trader would) + long_val = store.read_int_signal(paths.signal_long()) + short_val = store.read_int_signal(paths.signal_short()) + long_pm = store.read_signal(paths.profit_margin_long()) + + assert long_val == 5 + assert short_val == 0 + assert long_pm == 2.5 diff --git a/tests/integration/test_trainer_runner.py b/tests/integration/test_trainer_runner.py new file mode 100644 index 000000000..210dce751 --- /dev/null +++ b/tests/integration/test_trainer_runner.py @@ -0,0 +1,365 @@ +"""Integration tests for TrainerRunner. + +Uses a mock market client to verify the full training pipeline without +hitting real exchange APIs. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from powertrader.core.config import TradingConfig +from powertrader.core.constants import KILLER_FILENAME, TIMEFRAMES +from powertrader.core.market_client import MarketDataClient +from powertrader.core.storage import FileStore +from powertrader.models.candle import Candle +from powertrader.trainer.runner import TrainerRunner + +# --------------------------------------------------------------------------- +# Mock market client +# --------------------------------------------------------------------------- + + +class MockMarketClient(MarketDataClient): + """Returns deterministic candle data for testing.""" + + def __init__(self, candle_count: int = 50) -> None: + self._candle_count = candle_count + self.call_count = 0 + + def get_klines( + self, + symbol: str, + timeframe: str, + limit: int = 1500, + start_at: int | None = None, + end_at: int | None = None, + ) -> list[Candle]: + self.call_count += 1 + return self._make_candles(self._candle_count) + + def get_current_price(self, symbol: str) -> float: + return 50000.0 + + def get_all_klines( + self, + symbol: str, + timeframe: str, + max_candles: int = 100_000, + ) -> list[Candle]: + self.call_count += 1 + return self._make_candles(min(self._candle_count, max_candles)) + + @staticmethod + def _make_candles(count: int) -> list[Candle]: + """Generate deterministic candle data with upward trend.""" + candles = [] + base = 50000.0 + for i in range(count): + o = base + i * 10.0 + c = o + 5.0 + (i % 3) + h = max(o, c) + 20.0 + l = min(o, c) - 15.0 # noqa: E741 + candles.append( + Candle( + timestamp=1700000000 + i * 3600, + open=o, + close=c, + high=h, + low=l, + volume=100.0 + i, + ) + ) + return candles + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def base_dir(tmp_path: Path) -> Path: + """Project root with coin folder structure.""" + # BTC uses root + (tmp_path / "ETH").mkdir() + return tmp_path + + +@pytest.fixture +def config() -> TradingConfig: + return TradingConfig(coins=["BTC", "ETH"]) + + +@pytest.fixture +def store() -> FileStore: + return FileStore() + + +@pytest.fixture +def market() -> MockMarketClient: + return MockMarketClient(candle_count=30) + + +@pytest.fixture +def runner( + market: MockMarketClient, + config: TradingConfig, + store: FileStore, + base_dir: Path, +) -> TrainerRunner: + return TrainerRunner(market=market, config=config, store=store, base_dir=base_dir) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestTrainerRunnerRun: + """Test the full training pipeline.""" + + def test_trains_all_coins(self, runner: TrainerRunner, base_dir: Path) -> None: + """Should train all configured coins and create memory files.""" + runner.run() + + # BTC memory files in root + for tf in TIMEFRAMES: + assert (base_dir / f"memories_{tf}.txt").exists() + assert (base_dir / f"memory_weights_{tf}.txt").exists() + + # ETH memory files in subfolder + for tf in TIMEFRAMES: + assert (base_dir / "ETH" / f"memories_{tf}.txt").exists() + + def test_trains_single_coin(self, runner: TrainerRunner, base_dir: Path) -> None: + """Should train only the specified coin.""" + runner.run(coins=["BTC"]) + + # BTC should have memory files + assert (base_dir / "memories_1hour.txt").exists() + + # ETH should NOT (unless it existed before) + assert not (base_dir / "ETH" / "memories_1hour.txt").exists() + + def test_writes_status_file( + self, runner: TrainerRunner, base_dir: Path, store: FileStore + ) -> None: + """Should write trainer_status.json.""" + runner.run() + + status = store.read_json(base_dir / "trainer_status.json") + assert status is not None + assert status["state"] == "FINISHED" + + def test_clears_checkpoint_on_completion(self, runner: TrainerRunner, base_dir: Path) -> None: + """Checkpoint should be removed after successful training.""" + runner.run() + assert not (base_dir / "trainer_checkpoint.json").exists() + + def test_reprocess_rebuilds_memory( + self, runner: TrainerRunner, base_dir: Path, store: FileStore + ) -> None: + """Reprocess should rebuild memory from scratch.""" + # First run + runner.run(coins=["BTC"]) + first_mem = store.read_text(base_dir / "memories_1hour.txt") + + # Second run without reprocess should adjust existing + runner.run(coins=["BTC"]) + adjusted_mem = store.read_text(base_dir / "memories_1hour.txt") + # Patterns should be same (adjusting weights, not patterns) + assert first_mem.count("~") == adjusted_mem.count("~") + + # Reprocess should rebuild + runner.run(coins=["BTC"], reprocess=True) + reprocessed_mem = store.read_text(base_dir / "memories_1hour.txt") + assert reprocessed_mem # Should have content + + def test_memory_files_have_content( + self, runner: TrainerRunner, base_dir: Path, store: FileStore + ) -> None: + """Memory files should contain actual pattern data.""" + runner.run(coins=["BTC"]) + + mem = store.read_text(base_dir / "memories_1hour.txt") + assert "~" in mem or mem.strip() # At least one pattern + + weights = store.read_text(base_dir / "memory_weights_1hour.txt") + assert weights.strip() # Should have weight values + + threshold = store.read_signal(base_dir / "neural_perfect_threshold_1hour.txt") + assert threshold > 0 + + +class TestTrainerRunnerStopSignal: + """Test graceful stop via killer.txt.""" + + def test_should_stop_when_killer_says_yes(self, runner: TrainerRunner, base_dir: Path) -> None: + (base_dir / KILLER_FILENAME).write_text("yes", encoding="utf-8") + assert runner.should_stop() is True + + def test_should_not_stop_when_killer_says_no( + self, runner: TrainerRunner, base_dir: Path + ) -> None: + (base_dir / KILLER_FILENAME).write_text("no", encoding="utf-8") + assert runner.should_stop() is False + + def test_should_not_stop_when_killer_missing(self, runner: TrainerRunner) -> None: + assert runner.should_stop() is False + + def test_stop_writes_interrupted_status( + self, + market: MockMarketClient, + config: TradingConfig, + store: FileStore, + base_dir: Path, + ) -> None: + """Stopping mid-training should write INTERRUPTED status.""" + # Write killer file before starting + (base_dir / KILLER_FILENAME).write_text("yes", encoding="utf-8") + + runner = TrainerRunner(market=market, config=config, store=store, base_dir=base_dir) + runner.run() + + status = store.read_json(base_dir / "trainer_status.json") + assert status["state"] == "INTERRUPTED" + + +class TestTrainerRunnerCheckpoint: + """Test checkpoint-based resume.""" + + def test_saves_checkpoint_during_training( + self, + market: MockMarketClient, + config: TradingConfig, + store: FileStore, + base_dir: Path, + ) -> None: + """Write killer mid-way to capture a checkpoint.""" + call_count = 0 + + class StoppingMarket(MockMarketClient): + def get_all_klines( + self, symbol: str, timeframe: str, max_candles: int = 100_000 + ) -> list[Candle]: + nonlocal call_count + call_count += 1 + if call_count > 2: + # Trigger stop after 2 timeframes + (base_dir / KILLER_FILENAME).write_text("yes") + return super().get_all_klines(symbol, timeframe, max_candles) + + runner = TrainerRunner( + market=StoppingMarket(candle_count=30), + config=TradingConfig(coins=["BTC"]), + store=store, + base_dir=base_dir, + ) + runner.run() + + # Status should be INTERRUPTED + status = store.read_json(base_dir / "trainer_status.json") + assert status["state"] == "INTERRUPTED" + + def test_resume_from_checkpoint( + self, + market: MockMarketClient, + store: FileStore, + base_dir: Path, + ) -> None: + """Write a checkpoint and verify runner resumes from it.""" + # Pre-write a checkpoint that says we left off at BTC, tf_index=3 + store.write_json( + base_dir / "trainer_checkpoint.json", + {"coin": "BTC", "tf_index": 3}, + ) + + runner = TrainerRunner( + market=market, + config=TradingConfig(coins=["BTC"]), + store=store, + base_dir=base_dir, + ) + runner.run() + + # Should have completed training (checkpoint cleared) + assert not (base_dir / "trainer_checkpoint.json").exists() + + # Should have memory files for timeframes starting at index 3+ + for tf in TIMEFRAMES[3:]: + assert (base_dir / f"memories_{tf}.txt").exists() + + +class TestTrainerRunnerEdgeCases: + """Test edge cases.""" + + def test_empty_candle_data( + self, + config: TradingConfig, + store: FileStore, + base_dir: Path, + ) -> None: + """Should handle empty candle data gracefully.""" + + class EmptyMarket(MockMarketClient): + def get_all_klines( + self, symbol: str, timeframe: str, max_candles: int = 100_000 + ) -> list[Candle]: + return [] + + runner = TrainerRunner( + market=EmptyMarket(), + config=config, + store=store, + base_dir=base_dir, + ) + runner.run() + + # Should complete without errors + status = store.read_json(base_dir / "trainer_status.json") + assert status["state"] == "FINISHED" + + def test_missing_coin_folder( + self, + market: MockMarketClient, + store: FileStore, + base_dir: Path, + ) -> None: + """Should skip coins without folders (non-BTC).""" + config = TradingConfig(coins=["BTC", "NONEXISTENT"]) + runner = TrainerRunner(market=market, config=config, store=store, base_dir=base_dir) + runner.run() + + # BTC should still be trained + assert (base_dir / "memories_1hour.txt").exists() + + def test_progress_callback( + self, + market: MockMarketClient, + config: TradingConfig, + store: FileStore, + base_dir: Path, + ) -> None: + """Progress callback should be called during training.""" + progress_calls: list[tuple[str, str, int, int]] = [] + + def on_progress(coin: str, tf: str, pos: int, total: int) -> None: + progress_calls.append((coin, tf, pos, total)) + + runner = TrainerRunner( + market=market, + config=TradingConfig(coins=["BTC"]), + store=store, + base_dir=base_dir, + on_progress=on_progress, + ) + runner.run() + + # With 30 candles and pattern_length=2, there will be + # progress callbacks during weight adjustment + # (may or may not be called depending on candle count) + # Just verify no errors occurred + assert (base_dir / "memories_1hour.txt").exists()