From 2dffaea7f27006dfe8d1faa62c20ec5b25999f90 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Feb 2026 11:44:16 +0000 Subject: [PATCH] Add Phase 2: Domain data models and types Implement all data model classes that define the contracts between components (trainer, thinker, trader, hub): - models/types.py: Type aliases (Timeframe, CoinSymbol, SignalLevel, PriceLevel) - models/candle.py: Immutable Candle dataclass with derived properties (body_pct, range_pct, shadows, direction) - models/signal.py: Immutable Signal dataclass with entry/neutral helpers and per-timeframe boundary validation - models/position.py: Mutable Position dataclass with avg_price, pnl_pct, market_value, and DCA tracking - models/trade.py: Immutable Trade dataclass with to_dict/from_dict for JSON-lines serialisation (supports both new and legacy schemas) - models/memory.py: Mutable PatternMemory dataclass with to_memory_text/ from_memory_text for the custom on-disk format (~ and {} delimiters) 138 new unit tests covering construction, properties, validation, and serialisation round-trips. All 359 tests pass, ruff/mypy/format clean. https://claude.ai/code/session_01YCzrLj8szZHwkZvrghNf5V --- plan.md | 14 +- src/powertrader/models/__init__.py | 25 +++ src/powertrader/models/candle.py | 119 +++++++++++ src/powertrader/models/memory.py | 200 ++++++++++++++++++ src/powertrader/models/position.py | 104 ++++++++++ src/powertrader/models/signal.py | 91 ++++++++ src/powertrader/models/trade.py | 163 +++++++++++++++ src/powertrader/models/types.py | 21 ++ tests/unit/models/__init__.py | 0 tests/unit/models/test_candle.py | 217 +++++++++++++++++++ tests/unit/models/test_init.py | 54 +++++ tests/unit/models/test_memory.py | 288 ++++++++++++++++++++++++++ tests/unit/models/test_position.py | 188 +++++++++++++++++ tests/unit/models/test_signal.py | 191 +++++++++++++++++ tests/unit/models/test_trade.py | 322 +++++++++++++++++++++++++++++ tests/unit/models/test_types.py | 25 +++ 16 files changed, 2015 insertions(+), 7 deletions(-) create mode 100644 src/powertrader/models/candle.py create mode 100644 src/powertrader/models/memory.py create mode 100644 src/powertrader/models/position.py create mode 100644 src/powertrader/models/signal.py create mode 100644 src/powertrader/models/trade.py create mode 100644 src/powertrader/models/types.py create mode 100644 tests/unit/models/__init__.py create mode 100644 tests/unit/models/test_candle.py create mode 100644 tests/unit/models/test_init.py create mode 100644 tests/unit/models/test_memory.py create mode 100644 tests/unit/models/test_position.py create mode 100644 tests/unit/models/test_signal.py create mode 100644 tests/unit/models/test_trade.py create mode 100644 tests/unit/models/test_types.py diff --git a/plan.md b/plan.md index d2df49cf8..e3fa6cfbf 100644 --- a/plan.md +++ b/plan.md @@ -324,13 +324,13 @@ PriceLevel: TypeAlias = float ``` **Phase 2 Deliverables:** -- [ ] `models/candle.py` -- [ ] `models/signal.py` -- [ ] `models/position.py` -- [ ] `models/trade.py` -- [ ] `models/memory.py` -- [ ] `models/types.py` -- [ ] Unit tests for all model validation and properties +- [x] `models/candle.py` +- [x] `models/signal.py` +- [x] `models/position.py` +- [x] `models/trade.py` +- [x] `models/memory.py` +- [x] `models/types.py` +- [x] Unit tests for all model validation and properties (138 tests) --- diff --git a/src/powertrader/models/__init__.py b/src/powertrader/models/__init__.py index e69de29bb..0af61bf98 100644 --- a/src/powertrader/models/__init__.py +++ b/src/powertrader/models/__init__.py @@ -0,0 +1,25 @@ +"""Domain data models for PowerTrader AI. + +Re-exports all model classes for convenient imports:: + + from powertrader.models import Candle, Signal, Position, Trade, PatternMemory +""" + +from powertrader.models.candle import Candle +from powertrader.models.memory import PatternMemory +from powertrader.models.position import Position +from powertrader.models.signal import Signal +from powertrader.models.trade import Trade +from powertrader.models.types import CoinSymbol, PriceLevel, SignalLevel, Timeframe + +__all__ = [ + "Candle", + "CoinSymbol", + "PatternMemory", + "Position", + "PriceLevel", + "Signal", + "SignalLevel", + "Timeframe", + "Trade", +] diff --git a/src/powertrader/models/candle.py b/src/powertrader/models/candle.py new file mode 100644 index 000000000..8f063eb57 --- /dev/null +++ b/src/powertrader/models/candle.py @@ -0,0 +1,119 @@ +"""OHLCV candle data model. + +Represents a single candlestick bar as returned by market data APIs +(KuCoin klines). Immutable so it can be safely shared and cached. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class Candle: + """A single OHLCV candlestick bar. + + Parameters + ---------- + timestamp: + Candle open time as a Unix epoch in **seconds**. + open, high, low, close: + Price values for the bar. + volume: + Traded volume in the base asset during this bar. + """ + + timestamp: int + open: float + high: float + low: float + close: float + volume: float + + # -- derived properties --------------------------------------------------- + + @property + def body_pct(self) -> float: + """Percentage change from open to close: ``(close - open) / open * 100``. + + Returns ``0.0`` if *open* is zero (degenerate candle). + """ + if self.open == 0.0: + return 0.0 + return (self.close - self.open) / self.open * 100.0 + + @property + def range_pct(self) -> float: + """Total bar range as a percentage of the low: ``(high - low) / low * 100``. + + Returns ``0.0`` if *low* is zero. + """ + if self.low == 0.0: + return 0.0 + return (self.high - self.low) / self.low * 100.0 + + @property + def upper_shadow_pct(self) -> float: + """Upper shadow as a percentage of *open*: ``(high - max(open, close)) / open * 100``. + + Returns ``0.0`` if *open* is zero. + """ + if self.open == 0.0: + return 0.0 + top = max(self.open, self.close) + return (self.high - top) / self.open * 100.0 + + @property + def lower_shadow_pct(self) -> float: + """Lower shadow as a percentage of *open*: ``(min(open, close) - low) / open * 100``. + + Returns ``0.0`` if *open* is zero. + """ + if self.open == 0.0: + return 0.0 + bottom = min(self.open, self.close) + return (bottom - self.low) / self.open * 100.0 + + @property + def is_bullish(self) -> bool: + """``True`` if the close is strictly above the open.""" + return self.close > self.open + + @property + def is_bearish(self) -> bool: + """``True`` if the close is strictly below the open.""" + return self.close < self.open + + @property + def mid(self) -> float: + """Midpoint price: ``(high + low) / 2``.""" + return (self.high + self.low) / 2.0 + + # -- validation ----------------------------------------------------------- + + def validate(self) -> list[str]: + """Return a list of validation errors (empty means valid).""" + errors: list[str] = [] + if self.timestamp < 0: + errors.append(f"timestamp={self.timestamp} must be >= 0.") + if self.open < 0: + errors.append(f"open={self.open} must be >= 0.") + if self.high < 0: + errors.append(f"high={self.high} must be >= 0.") + if self.low < 0: + errors.append(f"low={self.low} must be >= 0.") + if self.close < 0: + errors.append(f"close={self.close} must be >= 0.") + if self.volume < 0: + errors.append(f"volume={self.volume} must be >= 0.") + if self.high < self.low: + errors.append(f"high={self.high} must be >= low={self.low}.") + if self.high < self.open: + errors.append(f"high={self.high} must be >= open={self.open}.") + if self.high < self.close: + errors.append(f"high={self.high} must be >= close={self.close}.") + if self.low > self.open: + errors.append(f"low={self.low} must be <= open={self.open}.") + if self.low > self.close: + errors.append(f"low={self.low} must be <= close={self.close}.") + return errors diff --git a/src/powertrader/models/memory.py b/src/powertrader/models/memory.py new file mode 100644 index 000000000..b05781ff5 --- /dev/null +++ b/src/powertrader/models/memory.py @@ -0,0 +1,200 @@ +"""Pattern memory data model. + +A :class:`PatternMemory` holds the trained price-pattern data for a +single coin on a single timeframe. The trainer builds these from +historical kline data, and the thinker reads them to generate signals. + +File format +----------- +``memories_.txt`` uses a custom delimited text format: + +- Patterns are separated by ``~`` +- Each pattern has three fields separated by ``{}``:: + + candle_pcts{}high_diff{}low_diff + + where *candle_pcts* is a space-separated sequence of percentage + changes, *high_diff* is the predicted-high deviation, and *low_diff* + is the predicted-low deviation. + +Parallel weight files (``memory_weights_.txt``, etc.) contain +space-separated floats — one weight per pattern. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +# Delimiters used in the on-disk memory format. +PATTERN_SEPARATOR: str = "~" +FIELD_SEPARATOR: str = "{}" + + +@dataclass(slots=True) +class PatternMemory: + """Trained pattern memory for one coin / one timeframe. + + Parameters + ---------- + patterns: + List of patterns. Each pattern is a list of floats representing + the candle-body percentage changes that define the shape. + high_diffs: + Predicted-high deviation for each pattern (parallel to *patterns*). + low_diffs: + Predicted-low deviation for each pattern (parallel to *patterns*). + weights: + Reliability weight for the base prediction of each pattern. + weights_high: + Reliability weight for the high prediction of each pattern. + weights_low: + Reliability weight for the low prediction of each pattern. + threshold: + Maximum distance for a current candle sequence to "match" a + stored pattern. + """ + + patterns: list[list[float]] = field(default_factory=list) + high_diffs: list[float] = field(default_factory=list) + low_diffs: list[float] = field(default_factory=list) + weights: list[float] = field(default_factory=list) + weights_high: list[float] = field(default_factory=list) + weights_low: list[float] = field(default_factory=list) + threshold: float = 1.0 + + # -- derived properties --------------------------------------------------- + + @property + def size(self) -> int: + """Number of stored patterns.""" + return len(self.patterns) + + @property + def is_empty(self) -> bool: + """``True`` if no patterns have been stored.""" + return len(self.patterns) == 0 + + # -- serialisation (on-disk format) --------------------------------------- + + def to_memory_text(self) -> str: + """Serialise patterns to the ``memories_.txt`` format. + + Each pattern is rendered as:: + + candle_pct1 candle_pct2{}high_diff{}low_diff + + Patterns are joined by ``~``. + """ + parts: list[str] = [] + for i, pat in enumerate(self.patterns): + candle_str = " ".join(str(v) for v in pat) + high = self.high_diffs[i] if i < len(self.high_diffs) else 0.0 + low = self.low_diffs[i] if i < len(self.low_diffs) else 0.0 + parts.append(f"{candle_str}{FIELD_SEPARATOR}{high}{FIELD_SEPARATOR}{low}") + return PATTERN_SEPARATOR.join(parts) + + @classmethod + def from_memory_text( + cls, + text: str, + weights_text: str = "", + weights_high_text: str = "", + weights_low_text: str = "", + threshold: float = 1.0, + ) -> PatternMemory: + """Parse from the on-disk ``memories_.txt`` format. + + Parameters + ---------- + text: + Contents of ``memories_.txt``. + weights_text: + Contents of ``memory_weights_.txt`` (space-separated floats). + weights_high_text: + Contents of ``memory_weights_high_.txt``. + weights_low_text: + Contents of ``memory_weights_low_.txt``. + threshold: + Value from ``neural_perfect_threshold_.txt``. + """ + patterns: list[list[float]] = [] + high_diffs: list[float] = [] + low_diffs: list[float] = [] + + raw_patterns = text.strip().split(PATTERN_SEPARATOR) if text.strip() else [] + for raw in raw_patterns: + raw = raw.strip() + if not raw: + continue + fields = raw.split(FIELD_SEPARATOR) + # Field 0: candle percentages (space-separated) + candle_pcts = _parse_floats_space(fields[0]) if fields else [] + if not candle_pcts: + continue + patterns.append(candle_pcts) + # Field 1: high_diff + high_diffs.append(_safe_float(fields[1]) if len(fields) > 1 else 0.0) + # Field 2: low_diff + low_diffs.append(_safe_float(fields[2]) if len(fields) > 2 else 0.0) + + return cls( + patterns=patterns, + high_diffs=high_diffs, + low_diffs=low_diffs, + weights=_parse_floats_space(weights_text), + weights_high=_parse_floats_space(weights_high_text), + weights_low=_parse_floats_space(weights_low_text), + threshold=threshold, + ) + + # -- validation ----------------------------------------------------------- + + def validate(self) -> list[str]: + """Return a list of validation errors (empty means valid).""" + errors: list[str] = [] + n = len(self.patterns) + if len(self.high_diffs) != n: + errors.append(f"high_diffs length ({len(self.high_diffs)}) != patterns length ({n}).") + if len(self.low_diffs) != n: + errors.append(f"low_diffs length ({len(self.low_diffs)}) != patterns length ({n}).") + if self.weights and len(self.weights) != n: + errors.append(f"weights length ({len(self.weights)}) != patterns length ({n}).") + if self.weights_high and len(self.weights_high) != n: + errors.append( + f"weights_high length ({len(self.weights_high)}) != patterns length ({n})." + ) + if self.weights_low and len(self.weights_low) != n: + errors.append( + f"weights_low length ({len(self.weights_low)}) != patterns length ({n})." + ) + if self.threshold < 0: + errors.append(f"threshold={self.threshold} must be >= 0.") + return errors + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _parse_floats_space(text: str) -> list[float]: + """Parse space-separated floats, skipping blanks.""" + if not text or not text.strip(): + return [] + result: list[float] = [] + for tok in text.strip().split(): + tok = tok.strip() + if tok: + try: + result.append(float(tok)) + except ValueError: + continue + return result + + +def _safe_float(text: str) -> float: + """Parse a single float, returning ``0.0`` on failure.""" + try: + return float(text.strip()) + except (ValueError, AttributeError): + return 0.0 diff --git a/src/powertrader/models/position.py b/src/powertrader/models/position.py new file mode 100644 index 000000000..fd0e3cfd6 --- /dev/null +++ b/src/powertrader/models/position.py @@ -0,0 +1,104 @@ +"""Open position data model. + +A :class:`Position` tracks the current state of a coin holding — +cost basis, DCA history, and trailing profit-margin state. + +Unlike the other models this is *mutable* because the trader updates +position state in-place as prices change and DCA buys occur. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass(slots=True) +class Position: + """Tracks a single open coin position. + + Parameters + ---------- + coin: + Coin ticker, e.g. ``"BTC"``. + entry_price: + Price at which the initial buy was executed. + quantity: + Total quantity held (base asset units). + cost_basis_usd: + Total USD spent to acquire the current quantity + (sum of all buys including DCA). + dca_count: + Number of DCA buys executed for this position. + dca_timestamps: + Unix epoch (seconds) of each DCA buy, used for the rolling + 24-hour rate limit window. + trailing_active: + ``True`` when the price has reached the profit-margin start + line and trailing tracking is engaged. + trailing_peak: + Highest price observed since trailing became active. + trailing_line: + Current trailing exit line (peak minus trailing gap). + """ + + coin: str + entry_price: float + quantity: float + cost_basis_usd: float = 0.0 + dca_count: int = 0 + dca_timestamps: list[float] = field(default_factory=list) + trailing_active: bool = False + trailing_peak: float = 0.0 + trailing_line: float = 0.0 + + # -- derived properties --------------------------------------------------- + + @property + def avg_price(self) -> float: + """Average cost per unit: ``cost_basis_usd / quantity``. + + Returns ``0.0`` if *quantity* is zero. + """ + if self.quantity == 0.0: + return 0.0 + return self.cost_basis_usd / self.quantity + + @property + def has_dca(self) -> bool: + """``True`` if at least one DCA buy has been executed.""" + return self.dca_count > 0 + + def pnl_pct(self, current_price: float) -> float: + """Unrealised PnL percentage at *current_price*. + + Returns ``0.0`` if *avg_price* is zero. + """ + avg = self.avg_price + if avg == 0.0: + return 0.0 + return (current_price - avg) / avg * 100.0 + + def market_value(self, current_price: float) -> float: + """Current market value of the position in quote currency.""" + return self.quantity * current_price + + # -- validation ----------------------------------------------------------- + + def validate(self) -> list[str]: + """Return a list of validation errors (empty means valid).""" + errors: list[str] = [] + if not self.coin: + errors.append("coin must not be empty.") + if self.entry_price < 0: + errors.append(f"entry_price={self.entry_price} must be >= 0.") + if self.quantity < 0: + errors.append(f"quantity={self.quantity} must be >= 0.") + if self.cost_basis_usd < 0: + errors.append(f"cost_basis_usd={self.cost_basis_usd} must be >= 0.") + if self.dca_count < 0: + errors.append(f"dca_count={self.dca_count} must be >= 0.") + if self.trailing_peak < 0: + errors.append(f"trailing_peak={self.trailing_peak} must be >= 0.") + if self.trailing_line < 0: + errors.append(f"trailing_line={self.trailing_line} must be >= 0.") + return errors diff --git a/src/powertrader/models/signal.py b/src/powertrader/models/signal.py new file mode 100644 index 000000000..4c0af6d73 --- /dev/null +++ b/src/powertrader/models/signal.py @@ -0,0 +1,91 @@ +"""Trading signal data model. + +A :class:`Signal` is the output of the thinker — it summarises how many +predicted high/low boundary levels the current price has broken through +for a given coin, plus the per-timeframe boundary prices. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from powertrader.core.constants import SIGNAL_MAX, SIGNAL_MIN, TIMEFRAMES + +NUM_TIMEFRAMES: int = len(TIMEFRAMES) + + +@dataclass(frozen=True, slots=True) +class Signal: + """Snapshot of the neural signal state for a single coin. + + Parameters + ---------- + coin: + Coin ticker, e.g. ``"BTC"``. + long_level: + LONG signal strength (0 = no signal, 7 = max confidence). + short_level: + SHORT signal strength (0 = no signal, 7 = max confidence). + long_bounds: + Per-timeframe low boundary prices (7 values, one per timeframe). + These are the predicted support levels — price breaking *below* + each level increments ``long_level``. + short_bounds: + Per-timeframe high boundary prices (7 values, one per timeframe). + These are the predicted resistance levels — price breaking *above* + each level increments ``short_level``. + long_profit_margin: + Aggregated profit-margin hint for long trades (percentage). + short_profit_margin: + Aggregated profit-margin hint for short trades (percentage). + timestamp: + Unix epoch (seconds) when this signal was generated. + """ + + coin: str + long_level: int = 0 + short_level: int = 0 + long_bounds: list[float] = field(default_factory=list) + short_bounds: list[float] = field(default_factory=list) + long_profit_margin: float = 0.0 + short_profit_margin: float = 0.0 + timestamp: float = 0.0 + + # -- convenience ---------------------------------------------------------- + + @property + def is_long_entry(self) -> bool: + """``True`` when long signal is strong and no short signal.""" + return self.long_level >= 3 and self.short_level == 0 + + @property + def is_neutral(self) -> bool: + """``True`` when both signal levels are zero.""" + return self.long_level == 0 and self.short_level == 0 + + # -- validation ----------------------------------------------------------- + + def validate(self) -> list[str]: + """Return a list of validation errors (empty means valid).""" + errors: list[str] = [] + if not self.coin: + errors.append("coin must not be empty.") + if not SIGNAL_MIN <= self.long_level <= SIGNAL_MAX: + errors.append(f"long_level={self.long_level} outside {SIGNAL_MIN}-{SIGNAL_MAX} range.") + if not SIGNAL_MIN <= self.short_level <= SIGNAL_MAX: + errors.append( + f"short_level={self.short_level} outside {SIGNAL_MIN}-{SIGNAL_MAX} range." + ) + if len(self.long_bounds) not in (0, NUM_TIMEFRAMES): + errors.append( + f"long_bounds has {len(self.long_bounds)} elements, " + f"expected 0 or {NUM_TIMEFRAMES}." + ) + if len(self.short_bounds) not in (0, NUM_TIMEFRAMES): + errors.append( + f"short_bounds has {len(self.short_bounds)} elements, " + f"expected 0 or {NUM_TIMEFRAMES}." + ) + if self.timestamp < 0: + errors.append(f"timestamp={self.timestamp} must be >= 0.") + return errors diff --git a/src/powertrader/models/trade.py b/src/powertrader/models/trade.py new file mode 100644 index 000000000..39e48c2f4 --- /dev/null +++ b/src/powertrader/models/trade.py @@ -0,0 +1,163 @@ +"""Executed trade data model. + +A :class:`Trade` is an immutable record of a single executed order — buy +or sell. It is appended to ``trade_history.jsonl`` by the trader and +displayed in the Hub GUI. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +_VALID_SIDES = frozenset({"BUY", "SELL"}) + + +@dataclass(frozen=True, slots=True) +class Trade: + """Record of a single executed trade. + + Parameters + ---------- + coin: + Coin ticker, e.g. ``"BTC"``. + side: + ``"BUY"`` or ``"SELL"``. + price: + Execution / average fill price. + quantity: + Quantity in base asset units. + value: + Total quote value of the trade (``price * quantity``). + reason: + Why this trade was placed — e.g. ``"entry"``, ``"dca_stage_1"``, + ``"dca_stage_3"``, ``"trailing_exit"``. + timestamp: + Unix epoch (seconds) when the order was filled. + pnl_pct: + Realised profit/loss percentage for sell trades. + ``None`` for buy trades. + fees_usd: + Exchange fees paid in USD (if known). + order_id: + Exchange order ID string (if available). + """ + + coin: str + side: str + price: float + quantity: float + value: float + reason: str + timestamp: float + pnl_pct: float | None = None + fees_usd: float | None = None + order_id: str | None = None + + # -- convenience ---------------------------------------------------------- + + @property + def is_buy(self) -> bool: + """``True`` if this trade is a buy.""" + return self.side == "BUY" + + @property + def is_sell(self) -> bool: + """``True`` if this trade is a sell.""" + return self.side == "SELL" + + @property + def is_dca(self) -> bool: + """``True`` if the trade reason indicates a DCA buy.""" + return self.reason.startswith("dca_") + + # -- serialisation -------------------------------------------------------- + + def to_dict(self) -> dict[str, object]: + """Convert to a dictionary suitable for JSON-lines serialisation. + + Keys match the existing ``trade_history.jsonl`` schema used by + the trader. + """ + return { + "ts": self.timestamp, + "side": self.side.lower(), + "tag": self.reason, + "symbol": self.coin, + "qty": self.quantity, + "price": self.price, + "pnl_pct": self.pnl_pct, + "fees_usd": self.fees_usd, + "order_id": self.order_id, + } + + @classmethod + def from_dict(cls, data: dict[str, object]) -> Trade: + """Reconstruct a Trade from a ``trade_history.jsonl`` record. + + Handles both the new schema (``coin``, ``side`` upper) and the + legacy schema (``symbol``, ``side`` lower, ``ts``, ``tag``). + """ + side_raw = str(data.get("side", "BUY")).upper() + coin = str(data.get("coin") or data.get("symbol") or "") + + def _get_float(key: str, *alt_keys: str, default: float = 0.0) -> float: + for k in (key, *alt_keys): + v = data.get(k) + if v is not None: + try: + return float(str(v)) + except (TypeError, ValueError): + continue + return default + + return cls( + coin=coin, + side=side_raw, + price=_get_float("price"), + quantity=_get_float("qty", "quantity"), + value=_get_float("value"), + reason=str(data.get("reason") or data.get("tag") or ""), + timestamp=_get_float("timestamp", "ts"), + pnl_pct=_opt_float(data.get("pnl_pct")), + fees_usd=_opt_float(data.get("fees_usd")), + order_id=_opt_str(data.get("order_id")), + ) + + # -- validation ----------------------------------------------------------- + + def validate(self) -> list[str]: + """Return a list of validation errors (empty means valid).""" + errors: list[str] = [] + if not self.coin: + errors.append("coin must not be empty.") + if self.side not in _VALID_SIDES: + errors.append(f"side={self.side!r} must be 'BUY' or 'SELL'.") + if self.price < 0: + errors.append(f"price={self.price} must be >= 0.") + if self.quantity < 0: + errors.append(f"quantity={self.quantity} must be >= 0.") + if self.value < 0: + errors.append(f"value={self.value} must be >= 0.") + if self.timestamp < 0: + errors.append(f"timestamp={self.timestamp} must be >= 0.") + return errors + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _opt_float(val: object) -> float | None: + if val is None: + return None + try: + return float(val) # type: ignore[arg-type] + except (TypeError, ValueError): + return None + + +def _opt_str(val: object) -> str | None: + if val is None: + return None + return str(val) diff --git a/src/powertrader/models/types.py b/src/powertrader/models/types.py new file mode 100644 index 000000000..a64d59b13 --- /dev/null +++ b/src/powertrader/models/types.py @@ -0,0 +1,21 @@ +"""Domain-specific type aliases for PowerTrader AI. + +These aliases document intent at call sites without introducing runtime cost. +Use them in type annotations to make function signatures self-documenting. +""" + +from __future__ import annotations + +from typing import TypeAlias + +# A timeframe identifier — one of the values in ``core.constants.TIMEFRAMES``. +Timeframe: TypeAlias = str + +# A coin ticker symbol, e.g. ``"BTC"``, ``"ETH"``. +CoinSymbol: TypeAlias = str + +# A neural signal level in the range 0-7 (inclusive). +SignalLevel: TypeAlias = int + +# A price value (always positive float). +PriceLevel: TypeAlias = float diff --git a/tests/unit/models/__init__.py b/tests/unit/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/models/test_candle.py b/tests/unit/models/test_candle.py new file mode 100644 index 000000000..1c8ae48b0 --- /dev/null +++ b/tests/unit/models/test_candle.py @@ -0,0 +1,217 @@ +"""Tests for powertrader.models.candle.""" + +from __future__ import annotations + +import pytest + +from powertrader.models.candle import Candle + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def bullish_candle() -> Candle: + """A standard bullish candle: close > open.""" + return Candle( + timestamp=1700000000, + open=100.0, + high=110.0, + low=95.0, + close=108.0, + volume=500.0, + ) + + +@pytest.fixture +def bearish_candle() -> Candle: + """A standard bearish candle: close < open.""" + return Candle( + timestamp=1700000000, + open=108.0, + high=110.0, + low=95.0, + close=100.0, + volume=300.0, + ) + + +@pytest.fixture +def doji_candle() -> Candle: + """A doji candle: open == close.""" + return Candle( + timestamp=1700000000, + open=100.0, + high=105.0, + low=95.0, + close=100.0, + volume=200.0, + ) + + +# --------------------------------------------------------------------------- +# Construction & immutability +# --------------------------------------------------------------------------- + + +class TestConstruction: + def test_fields_stored(self, bullish_candle: Candle) -> None: + assert bullish_candle.timestamp == 1700000000 + assert bullish_candle.open == 100.0 + assert bullish_candle.high == 110.0 + assert bullish_candle.low == 95.0 + assert bullish_candle.close == 108.0 + assert bullish_candle.volume == 500.0 + + def test_frozen(self, bullish_candle: Candle) -> None: + with pytest.raises(AttributeError): + bullish_candle.close = 999.0 # type: ignore[misc] + + def test_equality(self) -> None: + a = Candle(1, 100.0, 110.0, 90.0, 105.0, 10.0) + b = Candle(1, 100.0, 110.0, 90.0, 105.0, 10.0) + assert a == b + + def test_inequality(self) -> None: + a = Candle(1, 100.0, 110.0, 90.0, 105.0, 10.0) + b = Candle(1, 100.0, 110.0, 90.0, 106.0, 10.0) + assert a != b + + +# --------------------------------------------------------------------------- +# Derived properties +# --------------------------------------------------------------------------- + + +class TestBodyPct: + def test_bullish(self, bullish_candle: Candle) -> None: + # (108 - 100) / 100 * 100 = 8.0% + assert bullish_candle.body_pct == pytest.approx(8.0) + + def test_bearish(self, bearish_candle: Candle) -> None: + # (100 - 108) / 108 * 100 ≈ -7.407% + assert bearish_candle.body_pct == pytest.approx(-7.407407, rel=1e-4) + + def test_doji(self, doji_candle: Candle) -> None: + assert doji_candle.body_pct == pytest.approx(0.0) + + def test_zero_open(self) -> None: + c = Candle(0, 0.0, 10.0, 0.0, 5.0, 1.0) + assert c.body_pct == 0.0 + + +class TestRangePct: + def test_normal(self, bullish_candle: Candle) -> None: + # (110 - 95) / 95 * 100 ≈ 15.789% + assert bullish_candle.range_pct == pytest.approx(15.789473, rel=1e-4) + + def test_zero_low(self) -> None: + c = Candle(0, 0.0, 10.0, 0.0, 5.0, 1.0) + assert c.range_pct == 0.0 + + def test_flat_candle(self) -> None: + c = Candle(0, 100.0, 100.0, 100.0, 100.0, 1.0) + assert c.range_pct == 0.0 + + +class TestShadows: + def test_upper_shadow_bullish(self, bullish_candle: Candle) -> None: + # upper shadow = high - max(open, close) = 110 - 108 = 2 + # as % of open: 2/100*100 = 2.0% + assert bullish_candle.upper_shadow_pct == pytest.approx(2.0) + + def test_upper_shadow_bearish(self, bearish_candle: Candle) -> None: + # upper shadow = 110 - max(108, 100) = 110 - 108 = 2 + # as % of open: 2/108*100 ≈ 1.852% + assert bearish_candle.upper_shadow_pct == pytest.approx(1.8518, rel=1e-3) + + def test_lower_shadow_bullish(self, bullish_candle: Candle) -> None: + # lower shadow = min(100, 108) - 95 = 100 - 95 = 5 + # as % of open: 5/100*100 = 5.0% + assert bullish_candle.lower_shadow_pct == pytest.approx(5.0) + + def test_lower_shadow_bearish(self, bearish_candle: Candle) -> None: + # lower shadow = min(108, 100) - 95 = 100 - 95 = 5 + # as % of open: 5/108*100 ≈ 4.630% + assert bearish_candle.lower_shadow_pct == pytest.approx(4.6296, rel=1e-3) + + def test_zero_open_shadows(self) -> None: + c = Candle(0, 0.0, 10.0, 0.0, 5.0, 1.0) + assert c.upper_shadow_pct == 0.0 + assert c.lower_shadow_pct == 0.0 + + +class TestDirectionProperties: + def test_bullish(self, bullish_candle: Candle) -> None: + assert bullish_candle.is_bullish is True + assert bullish_candle.is_bearish is False + + def test_bearish(self, bearish_candle: Candle) -> None: + assert bearish_candle.is_bullish is False + assert bearish_candle.is_bearish is True + + def test_doji(self, doji_candle: Candle) -> None: + assert doji_candle.is_bullish is False + assert doji_candle.is_bearish is False + + +class TestMid: + def test_mid(self, bullish_candle: Candle) -> None: + # (110 + 95) / 2 = 102.5 + assert bullish_candle.mid == pytest.approx(102.5) + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +class TestValidation: + def test_valid_candle(self, bullish_candle: Candle) -> None: + assert bullish_candle.validate() == [] + + def test_negative_timestamp(self) -> None: + c = Candle(-1, 100.0, 110.0, 90.0, 105.0, 10.0) + errors = c.validate() + assert any("timestamp" in e for e in errors) + + def test_negative_prices(self) -> None: + c = Candle(0, -1.0, 110.0, 90.0, 105.0, 10.0) + errors = c.validate() + assert any("open" in e for e in errors) + + def test_negative_volume(self) -> None: + c = Candle(0, 100.0, 110.0, 90.0, 105.0, -1.0) + errors = c.validate() + assert any("volume" in e for e in errors) + + def test_high_below_low(self) -> None: + c = Candle(0, 100.0, 90.0, 110.0, 105.0, 10.0) + errors = c.validate() + assert any("high" in e and "low" in e for e in errors) + + def test_high_below_open(self) -> None: + c = Candle(0, 100.0, 95.0, 90.0, 93.0, 10.0) + errors = c.validate() + assert any("high" in e and "open" in e for e in errors) + + def test_high_below_close(self) -> None: + c = Candle(0, 90.0, 95.0, 85.0, 100.0, 10.0) + errors = c.validate() + assert any("high" in e and "close" in e for e in errors) + + def test_low_above_open(self) -> None: + c = Candle(0, 90.0, 110.0, 95.0, 105.0, 10.0) + errors = c.validate() + assert any("low" in e and "open" in e for e in errors) + + def test_low_above_close(self) -> None: + c = Candle(0, 100.0, 110.0, 95.0, 90.0, 10.0) + errors = c.validate() + assert any("low" in e and "close" in e for e in errors) + + def test_zero_prices_valid(self) -> None: + """Zero prices are allowed (degenerate but not invalid).""" + c = Candle(0, 0.0, 0.0, 0.0, 0.0, 0.0) + assert c.validate() == [] diff --git a/tests/unit/models/test_init.py b/tests/unit/models/test_init.py new file mode 100644 index 000000000..6409b2b88 --- /dev/null +++ b/tests/unit/models/test_init.py @@ -0,0 +1,54 @@ +"""Tests for powertrader.models package-level imports.""" + +from __future__ import annotations + + +class TestPackageImports: + """All model classes should be importable from the package root.""" + + def test_candle_importable(self) -> None: + from powertrader.models import Candle + + c = Candle(0, 100.0, 110.0, 90.0, 105.0, 10.0) + assert c.close == 105.0 + + def test_signal_importable(self) -> None: + from powertrader.models import Signal + + s = Signal(coin="BTC", long_level=3) + assert s.long_level == 3 + + def test_position_importable(self) -> None: + from powertrader.models import Position + + p = Position(coin="BTC", entry_price=100.0, quantity=1.0) + assert p.coin == "BTC" + + def test_trade_importable(self) -> None: + from powertrader.models import Trade + + t = Trade( + coin="BTC", + side="BUY", + price=100.0, + quantity=1.0, + value=100.0, + reason="entry", + timestamp=0.0, + ) + assert t.is_buy + + def test_pattern_memory_importable(self) -> None: + from powertrader.models import PatternMemory + + m = PatternMemory() + assert m.is_empty + + def test_type_aliases_importable(self) -> None: + from powertrader.models import CoinSymbol, PriceLevel, SignalLevel, Timeframe + + # These are just aliases, verify they exist + assert Timeframe is not None + assert CoinSymbol is not None + assert SignalLevel is not None + assert PriceLevel is not None diff --git a/tests/unit/models/test_memory.py b/tests/unit/models/test_memory.py new file mode 100644 index 000000000..157454587 --- /dev/null +++ b/tests/unit/models/test_memory.py @@ -0,0 +1,288 @@ +"""Tests for powertrader.models.memory.""" + +from __future__ import annotations + +import pytest + +from powertrader.models.memory import FIELD_SEPARATOR, PATTERN_SEPARATOR, PatternMemory + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def simple_memory() -> PatternMemory: + """A small memory with 3 patterns.""" + return PatternMemory( + patterns=[[1.5, 0.8], [-0.5, 0.3], [2.0, -1.0]], + high_diffs=[2.3, -1.2, 3.0], + low_diffs=[1.1, 0.8, -0.5], + weights=[1.0, 0.5, 0.8], + weights_high=[0.9, 0.6, 0.7], + weights_low=[1.1, 0.4, 0.9], + threshold=0.85, + ) + + +@pytest.fixture +def empty_memory() -> PatternMemory: + """An empty memory with no patterns.""" + return PatternMemory() + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +class TestConstruction: + def test_fields_stored(self, simple_memory: PatternMemory) -> None: + assert len(simple_memory.patterns) == 3 + assert simple_memory.patterns[0] == [1.5, 0.8] + assert simple_memory.high_diffs == [2.3, -1.2, 3.0] + assert simple_memory.low_diffs == [1.1, 0.8, -0.5] + assert simple_memory.weights == [1.0, 0.5, 0.8] + assert simple_memory.weights_high == [0.9, 0.6, 0.7] + assert simple_memory.weights_low == [1.1, 0.4, 0.9] + assert simple_memory.threshold == 0.85 + + def test_defaults(self, empty_memory: PatternMemory) -> None: + assert empty_memory.patterns == [] + assert empty_memory.high_diffs == [] + assert empty_memory.low_diffs == [] + assert empty_memory.weights == [] + assert empty_memory.weights_high == [] + assert empty_memory.weights_low == [] + assert empty_memory.threshold == 1.0 + + def test_mutable(self, simple_memory: PatternMemory) -> None: + """Memory is mutable — the trainer adjusts weights in-place.""" + simple_memory.weights[0] = 1.5 + assert simple_memory.weights[0] == 1.5 + + +# --------------------------------------------------------------------------- +# Derived properties +# --------------------------------------------------------------------------- + + +class TestProperties: + def test_size(self, simple_memory: PatternMemory) -> None: + assert simple_memory.size == 3 + + def test_size_empty(self, empty_memory: PatternMemory) -> None: + assert empty_memory.size == 0 + + def test_is_empty(self, empty_memory: PatternMemory) -> None: + assert empty_memory.is_empty is True + + def test_is_not_empty(self, simple_memory: PatternMemory) -> None: + assert simple_memory.is_empty is False + + +# --------------------------------------------------------------------------- +# Serialisation: to_memory_text +# --------------------------------------------------------------------------- + + +class TestToMemoryText: + def test_round_trip_structure(self, simple_memory: PatternMemory) -> None: + text = simple_memory.to_memory_text() + # Should have 3 patterns separated by ~ + parts = text.split(PATTERN_SEPARATOR) + assert len(parts) == 3 + + def test_pattern_format(self, simple_memory: PatternMemory) -> None: + text = simple_memory.to_memory_text() + first_pattern = text.split(PATTERN_SEPARATOR)[0] + fields = first_pattern.split(FIELD_SEPARATOR) + assert len(fields) == 3 + # candle pcts + assert "1.5" in fields[0] + assert "0.8" in fields[0] + # high_diff + assert fields[1].strip() == "2.3" + # low_diff + assert fields[2].strip() == "1.1" + + def test_empty_memory(self, empty_memory: PatternMemory) -> None: + assert empty_memory.to_memory_text() == "" + + +# --------------------------------------------------------------------------- +# Serialisation: from_memory_text +# --------------------------------------------------------------------------- + + +class TestFromMemoryText: + def test_basic_parse(self) -> None: + text = "1.5 0.8{}2.3{}1.1~-0.5 0.3{}-1.2{}0.8" + weights_text = "1.0 0.5" + weights_high_text = "0.9 0.6" + weights_low_text = "1.1 0.4" + + mem = PatternMemory.from_memory_text( + text, + weights_text=weights_text, + weights_high_text=weights_high_text, + weights_low_text=weights_low_text, + threshold=0.85, + ) + + assert mem.size == 2 + assert mem.patterns[0] == [1.5, 0.8] + assert mem.patterns[1] == [-0.5, 0.3] + assert mem.high_diffs == [2.3, -1.2] + assert mem.low_diffs == [1.1, 0.8] + assert mem.weights == [1.0, 0.5] + assert mem.weights_high == [0.9, 0.6] + assert mem.weights_low == [1.1, 0.4] + assert mem.threshold == 0.85 + + def test_empty_text(self) -> None: + mem = PatternMemory.from_memory_text("") + assert mem.is_empty + assert mem.threshold == 1.0 + + def test_whitespace_only(self) -> None: + mem = PatternMemory.from_memory_text(" \n ") + assert mem.is_empty + + def test_single_pattern(self) -> None: + text = "3.5 1.2 -0.8{}4.0{}2.0" + mem = PatternMemory.from_memory_text(text) + assert mem.size == 1 + assert mem.patterns[0] == [3.5, 1.2, -0.8] + assert mem.high_diffs == [4.0] + assert mem.low_diffs == [2.0] + + def test_missing_fields_default_to_zero(self) -> None: + """Pattern with only candle pcts, no high/low diffs.""" + text = "1.5 0.8" + mem = PatternMemory.from_memory_text(text) + assert mem.size == 1 + assert mem.high_diffs == [0.0] + assert mem.low_diffs == [0.0] + + def test_no_weights_yields_empty(self) -> None: + text = "1.0 2.0{}3.0{}4.0" + mem = PatternMemory.from_memory_text(text) + assert mem.weights == [] + assert mem.weights_high == [] + assert mem.weights_low == [] + + def test_blank_patterns_skipped(self) -> None: + """Blank entries between separators are skipped.""" + text = "1.5 0.8{}2.3{}1.1~~-0.5 0.3{}-1.2{}0.8" + mem = PatternMemory.from_memory_text(text) + assert mem.size == 2 + + +class TestRoundTrip: + def test_to_then_from(self, simple_memory: PatternMemory) -> None: + text = simple_memory.to_memory_text() + weights_text = " ".join(str(w) for w in simple_memory.weights) + weights_high_text = " ".join(str(w) for w in simple_memory.weights_high) + weights_low_text = " ".join(str(w) for w in simple_memory.weights_low) + + reconstructed = PatternMemory.from_memory_text( + text, + weights_text=weights_text, + weights_high_text=weights_high_text, + weights_low_text=weights_low_text, + threshold=simple_memory.threshold, + ) + + assert reconstructed.size == simple_memory.size + assert reconstructed.threshold == simple_memory.threshold + for i in range(simple_memory.size): + for j in range(len(simple_memory.patterns[i])): + assert reconstructed.patterns[i][j] == pytest.approx(simple_memory.patterns[i][j]) + assert reconstructed.high_diffs[i] == pytest.approx(simple_memory.high_diffs[i]) + assert reconstructed.low_diffs[i] == pytest.approx(simple_memory.low_diffs[i]) + assert reconstructed.weights == pytest.approx(simple_memory.weights) + assert reconstructed.weights_high == pytest.approx(simple_memory.weights_high) + assert reconstructed.weights_low == pytest.approx(simple_memory.weights_low) + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +class TestValidation: + def test_valid_memory(self, simple_memory: PatternMemory) -> None: + assert simple_memory.validate() == [] + + def test_valid_empty(self, empty_memory: PatternMemory) -> None: + assert empty_memory.validate() == [] + + def test_mismatched_high_diffs(self) -> None: + mem = PatternMemory( + patterns=[[1.0], [2.0]], + high_diffs=[1.0], # should be 2 + low_diffs=[1.0, 2.0], + ) + errors = mem.validate() + assert any("high_diffs" in e for e in errors) + + def test_mismatched_low_diffs(self) -> None: + mem = PatternMemory( + patterns=[[1.0], [2.0]], + high_diffs=[1.0, 2.0], + low_diffs=[1.0], # should be 2 + ) + errors = mem.validate() + assert any("low_diffs" in e for e in errors) + + def test_mismatched_weights(self) -> None: + mem = PatternMemory( + patterns=[[1.0], [2.0]], + high_diffs=[1.0, 2.0], + low_diffs=[1.0, 2.0], + weights=[1.0], # should be 2 (or empty) + ) + errors = mem.validate() + assert any("weights length" in e for e in errors) + + def test_mismatched_weights_high(self) -> None: + mem = PatternMemory( + patterns=[[1.0], [2.0]], + high_diffs=[1.0, 2.0], + low_diffs=[1.0, 2.0], + weights_high=[1.0], # should be 2 (or empty) + ) + errors = mem.validate() + assert any("weights_high" in e for e in errors) + + def test_mismatched_weights_low(self) -> None: + mem = PatternMemory( + patterns=[[1.0], [2.0]], + high_diffs=[1.0, 2.0], + low_diffs=[1.0, 2.0], + weights_low=[1.0], # should be 2 (or empty) + ) + errors = mem.validate() + assert any("weights_low" in e for e in errors) + + def test_empty_weights_valid(self) -> None: + """Empty weights are valid (means all patterns have default weight).""" + mem = PatternMemory( + patterns=[[1.0], [2.0]], + high_diffs=[1.0, 2.0], + low_diffs=[1.0, 2.0], + weights=[], + weights_high=[], + weights_low=[], + ) + assert mem.validate() == [] + + def test_negative_threshold(self) -> None: + mem = PatternMemory(threshold=-0.1) + errors = mem.validate() + assert any("threshold" in e for e in errors) + + def test_zero_threshold_valid(self) -> None: + mem = PatternMemory(threshold=0.0) + assert mem.validate() == [] diff --git a/tests/unit/models/test_position.py b/tests/unit/models/test_position.py new file mode 100644 index 000000000..e1416cc31 --- /dev/null +++ b/tests/unit/models/test_position.py @@ -0,0 +1,188 @@ +"""Tests for powertrader.models.position.""" + +from __future__ import annotations + +import pytest + +from powertrader.models.position import Position + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def fresh_position() -> Position: + """A newly opened position with no DCA.""" + return Position( + coin="BTC", + entry_price=42000.0, + quantity=0.01, + cost_basis_usd=420.0, + ) + + +@pytest.fixture +def dca_position() -> Position: + """A position that has been DCA'd twice.""" + return Position( + coin="ETH", + entry_price=2000.0, + quantity=3.0, + cost_basis_usd=5500.0, + dca_count=2, + dca_timestamps=[1700000000.0, 1700050000.0], + ) + + +@pytest.fixture +def trailing_position() -> Position: + """A position with trailing profit margin active.""" + return Position( + coin="BTC", + entry_price=42000.0, + quantity=0.01, + cost_basis_usd=420.0, + trailing_active=True, + trailing_peak=45000.0, + trailing_line=44775.0, # peak - 0.5% trailing gap + ) + + +# --------------------------------------------------------------------------- +# Construction & mutability +# --------------------------------------------------------------------------- + + +class TestConstruction: + def test_fields_stored(self, fresh_position: Position) -> None: + assert fresh_position.coin == "BTC" + assert fresh_position.entry_price == 42000.0 + assert fresh_position.quantity == 0.01 + assert fresh_position.cost_basis_usd == 420.0 + assert fresh_position.dca_count == 0 + assert fresh_position.dca_timestamps == [] + assert fresh_position.trailing_active is False + assert fresh_position.trailing_peak == 0.0 + assert fresh_position.trailing_line == 0.0 + + def test_mutable(self, fresh_position: Position) -> None: + """Positions are mutable — the trader updates them in-place.""" + fresh_position.quantity = 0.02 + assert fresh_position.quantity == 0.02 + + def test_defaults(self) -> None: + p = Position(coin="BTC", entry_price=100.0, quantity=1.0) + assert p.cost_basis_usd == 0.0 + assert p.dca_count == 0 + assert p.dca_timestamps == [] + assert p.trailing_active is False + + +# --------------------------------------------------------------------------- +# Derived properties +# --------------------------------------------------------------------------- + + +class TestAvgPrice: + def test_basic(self, fresh_position: Position) -> None: + # 420 / 0.01 = 42000 + assert fresh_position.avg_price == pytest.approx(42000.0) + + def test_after_dca(self, dca_position: Position) -> None: + # 5500 / 3.0 ≈ 1833.33 + assert dca_position.avg_price == pytest.approx(1833.333, rel=1e-3) + + def test_zero_quantity(self) -> None: + p = Position(coin="BTC", entry_price=100.0, quantity=0.0, cost_basis_usd=0.0) + assert p.avg_price == 0.0 + + +class TestHasDCA: + def test_no_dca(self, fresh_position: Position) -> None: + assert fresh_position.has_dca is False + + def test_with_dca(self, dca_position: Position) -> None: + assert dca_position.has_dca is True + + +class TestPnlPct: + def test_profitable(self, fresh_position: Position) -> None: + # avg = 42000, current = 44100 → (44100-42000)/42000*100 = 5.0% + assert fresh_position.pnl_pct(44100.0) == pytest.approx(5.0) + + def test_at_loss(self, fresh_position: Position) -> None: + # avg = 42000, current = 39900 → (39900-42000)/42000*100 = -5.0% + assert fresh_position.pnl_pct(39900.0) == pytest.approx(-5.0) + + def test_breakeven(self, fresh_position: Position) -> None: + assert fresh_position.pnl_pct(42000.0) == pytest.approx(0.0) + + def test_zero_avg(self) -> None: + p = Position(coin="BTC", entry_price=0.0, quantity=0.0, cost_basis_usd=0.0) + assert p.pnl_pct(100.0) == 0.0 + + def test_dca_lowers_avg(self, dca_position: Position) -> None: + """After DCA, avg is lower so same price gives higher PnL %.""" + # avg ≈ 1833.33, current = 2000 → (2000-1833.33)/1833.33*100 ≈ 9.09% + assert dca_position.pnl_pct(2000.0) == pytest.approx(9.0909, rel=1e-3) + + +class TestMarketValue: + def test_basic(self, fresh_position: Position) -> None: + # 0.01 * 44000 = 440 + assert fresh_position.market_value(44000.0) == pytest.approx(440.0) + + def test_zero_quantity(self) -> None: + p = Position(coin="BTC", entry_price=100.0, quantity=0.0) + assert p.market_value(50000.0) == 0.0 + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +class TestValidation: + def test_valid_position(self, fresh_position: Position) -> None: + assert fresh_position.validate() == [] + + def test_empty_coin(self) -> None: + p = Position(coin="", entry_price=100.0, quantity=1.0) + errors = p.validate() + assert any("coin" in e for e in errors) + + def test_negative_entry_price(self) -> None: + p = Position(coin="BTC", entry_price=-1.0, quantity=1.0) + errors = p.validate() + assert any("entry_price" in e for e in errors) + + def test_negative_quantity(self) -> None: + p = Position(coin="BTC", entry_price=100.0, quantity=-1.0) + errors = p.validate() + assert any("quantity" in e for e in errors) + + def test_negative_cost_basis(self) -> None: + p = Position(coin="BTC", entry_price=100.0, quantity=1.0, cost_basis_usd=-1.0) + errors = p.validate() + assert any("cost_basis_usd" in e for e in errors) + + def test_negative_dca_count(self) -> None: + p = Position(coin="BTC", entry_price=100.0, quantity=1.0, dca_count=-1) + errors = p.validate() + assert any("dca_count" in e for e in errors) + + def test_negative_trailing_peak(self) -> None: + p = Position(coin="BTC", entry_price=100.0, quantity=1.0, trailing_peak=-1.0) + errors = p.validate() + assert any("trailing_peak" in e for e in errors) + + def test_negative_trailing_line(self) -> None: + p = Position(coin="BTC", entry_price=100.0, quantity=1.0, trailing_line=-1.0) + errors = p.validate() + assert any("trailing_line" in e for e in errors) + + def test_zero_values_valid(self) -> None: + """Zero is valid for numeric fields.""" + p = Position(coin="BTC", entry_price=0.0, quantity=0.0, cost_basis_usd=0.0) + assert p.validate() == [] diff --git a/tests/unit/models/test_signal.py b/tests/unit/models/test_signal.py new file mode 100644 index 000000000..6771b81f8 --- /dev/null +++ b/tests/unit/models/test_signal.py @@ -0,0 +1,191 @@ +"""Tests for powertrader.models.signal.""" + +from __future__ import annotations + +import pytest + +from powertrader.models.signal import NUM_TIMEFRAMES, Signal + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def entry_signal() -> Signal: + """A signal that meets entry criteria: long >= 3, short == 0.""" + return Signal( + coin="BTC", + long_level=5, + short_level=0, + long_bounds=[95.0, 93.0, 90.0, 88.0, 85.0, 80.0, 75.0], + short_bounds=[110.0, 112.0, 115.0, 118.0, 120.0, 125.0, 130.0], + long_profit_margin=5.0, + short_profit_margin=3.0, + timestamp=1700000000.0, + ) + + +@pytest.fixture +def neutral_signal() -> Signal: + """A signal with no conviction in either direction.""" + return Signal( + coin="ETH", + long_level=0, + short_level=0, + timestamp=1700000000.0, + ) + + +@pytest.fixture +def mixed_signal() -> Signal: + """A signal with both long and short levels set.""" + return Signal( + coin="XRP", + long_level=4, + short_level=2, + timestamp=1700000000.0, + ) + + +# --------------------------------------------------------------------------- +# Construction & immutability +# --------------------------------------------------------------------------- + + +class TestConstruction: + def test_fields_stored(self, entry_signal: Signal) -> None: + assert entry_signal.coin == "BTC" + assert entry_signal.long_level == 5 + assert entry_signal.short_level == 0 + assert len(entry_signal.long_bounds) == 7 + assert len(entry_signal.short_bounds) == 7 + assert entry_signal.long_profit_margin == 5.0 + assert entry_signal.short_profit_margin == 3.0 + assert entry_signal.timestamp == 1700000000.0 + + def test_defaults(self) -> None: + s = Signal(coin="BTC") + assert s.long_level == 0 + assert s.short_level == 0 + assert s.long_bounds == [] + assert s.short_bounds == [] + assert s.long_profit_margin == 0.0 + assert s.short_profit_margin == 0.0 + assert s.timestamp == 0.0 + + def test_frozen(self, entry_signal: Signal) -> None: + with pytest.raises(AttributeError): + entry_signal.long_level = 7 # type: ignore[misc] + + def test_equality(self) -> None: + a = Signal(coin="BTC", long_level=3) + b = Signal(coin="BTC", long_level=3) + assert a == b + + +# --------------------------------------------------------------------------- +# Convenience properties +# --------------------------------------------------------------------------- + + +class TestConvenienceProperties: + def test_is_long_entry_true(self, entry_signal: Signal) -> None: + assert entry_signal.is_long_entry is True + + def test_is_long_entry_false_low_level(self) -> None: + s = Signal(coin="BTC", long_level=2, short_level=0) + assert s.is_long_entry is False + + def test_is_long_entry_false_has_short(self) -> None: + s = Signal(coin="BTC", long_level=5, short_level=1) + assert s.is_long_entry is False + + def test_is_long_entry_boundary(self) -> None: + """Level 3 is the minimum for entry.""" + s = Signal(coin="BTC", long_level=3, short_level=0) + assert s.is_long_entry is True + + def test_is_neutral(self, neutral_signal: Signal) -> None: + assert neutral_signal.is_neutral is True + + def test_is_neutral_false(self, entry_signal: Signal) -> None: + assert entry_signal.is_neutral is False + + def test_mixed_signal_not_entry(self, mixed_signal: Signal) -> None: + """Long >= 3 but short > 0 means no entry.""" + assert mixed_signal.is_long_entry is False + assert mixed_signal.is_neutral is False + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +class TestValidation: + def test_valid_signal(self, entry_signal: Signal) -> None: + assert entry_signal.validate() == [] + + def test_valid_minimal(self) -> None: + s = Signal(coin="BTC") + assert s.validate() == [] + + def test_empty_coin(self) -> None: + s = Signal(coin="") + errors = s.validate() + assert any("coin" in e for e in errors) + + def test_long_level_too_low(self) -> None: + s = Signal(coin="BTC", long_level=-1) + errors = s.validate() + assert any("long_level" in e for e in errors) + + def test_long_level_too_high(self) -> None: + s = Signal(coin="BTC", long_level=8) + errors = s.validate() + assert any("long_level" in e for e in errors) + + def test_short_level_too_low(self) -> None: + s = Signal(coin="BTC", short_level=-1) + errors = s.validate() + assert any("short_level" in e for e in errors) + + def test_short_level_too_high(self) -> None: + s = Signal(coin="BTC", short_level=8) + errors = s.validate() + assert any("short_level" in e for e in errors) + + def test_wrong_long_bounds_length(self) -> None: + s = Signal(coin="BTC", long_bounds=[1.0, 2.0, 3.0]) + errors = s.validate() + assert any("long_bounds" in e for e in errors) + + def test_wrong_short_bounds_length(self) -> None: + s = Signal(coin="BTC", short_bounds=[1.0] * 5) + errors = s.validate() + assert any("short_bounds" in e for e in errors) + + def test_correct_bounds_length(self) -> None: + s = Signal( + coin="BTC", + long_bounds=[1.0] * NUM_TIMEFRAMES, + short_bounds=[2.0] * NUM_TIMEFRAMES, + ) + assert s.validate() == [] + + def test_empty_bounds_valid(self) -> None: + s = Signal(coin="BTC", long_bounds=[], short_bounds=[]) + assert s.validate() == [] + + def test_negative_timestamp(self) -> None: + s = Signal(coin="BTC", timestamp=-1.0) + errors = s.validate() + assert any("timestamp" in e for e in errors) + + def test_boundary_levels_valid(self) -> None: + """Levels 0 and 7 are both valid.""" + s0 = Signal(coin="BTC", long_level=0, short_level=0) + s7 = Signal(coin="BTC", long_level=7, short_level=7) + assert s0.validate() == [] + assert s7.validate() == [] diff --git a/tests/unit/models/test_trade.py b/tests/unit/models/test_trade.py new file mode 100644 index 000000000..bb13fa039 --- /dev/null +++ b/tests/unit/models/test_trade.py @@ -0,0 +1,322 @@ +"""Tests for powertrader.models.trade.""" + +from __future__ import annotations + +import pytest + +from powertrader.models.trade import Trade + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def entry_buy() -> Trade: + """A standard entry buy trade.""" + return Trade( + coin="BTC", + side="BUY", + price=42000.0, + quantity=0.01, + value=420.0, + reason="entry", + timestamp=1700000000.0, + fees_usd=1.05, + order_id="abc123", + ) + + +@pytest.fixture +def dca_buy() -> Trade: + """A DCA buy trade.""" + return Trade( + coin="ETH", + side="BUY", + price=1800.0, + quantity=1.5, + value=2700.0, + reason="dca_stage_3", + timestamp=1700050000.0, + ) + + +@pytest.fixture +def exit_sell() -> Trade: + """A trailing exit sell trade.""" + return Trade( + coin="BTC", + side="SELL", + price=44100.0, + quantity=0.01, + value=441.0, + reason="trailing_exit", + timestamp=1700100000.0, + pnl_pct=5.0, + fees_usd=1.10, + order_id="xyz789", + ) + + +# --------------------------------------------------------------------------- +# Construction & immutability +# --------------------------------------------------------------------------- + + +class TestConstruction: + def test_fields_stored(self, entry_buy: Trade) -> None: + assert entry_buy.coin == "BTC" + assert entry_buy.side == "BUY" + assert entry_buy.price == 42000.0 + assert entry_buy.quantity == 0.01 + assert entry_buy.value == 420.0 + assert entry_buy.reason == "entry" + assert entry_buy.timestamp == 1700000000.0 + assert entry_buy.pnl_pct is None + assert entry_buy.fees_usd == 1.05 + assert entry_buy.order_id == "abc123" + + def test_defaults(self) -> None: + t = Trade( + coin="BTC", + side="BUY", + price=100.0, + quantity=1.0, + value=100.0, + reason="entry", + timestamp=0.0, + ) + assert t.pnl_pct is None + assert t.fees_usd is None + assert t.order_id is None + + def test_frozen(self, entry_buy: Trade) -> None: + with pytest.raises(AttributeError): + entry_buy.price = 999.0 # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# Convenience properties +# --------------------------------------------------------------------------- + + +class TestConvenience: + def test_is_buy(self, entry_buy: Trade) -> None: + assert entry_buy.is_buy is True + assert entry_buy.is_sell is False + + def test_is_sell(self, exit_sell: Trade) -> None: + assert exit_sell.is_sell is True + assert exit_sell.is_buy is False + + def test_is_dca_entry(self, entry_buy: Trade) -> None: + assert entry_buy.is_dca is False + + def test_is_dca_true(self, dca_buy: Trade) -> None: + assert dca_buy.is_dca is True + + def test_is_dca_various_stages(self) -> None: + for stage in range(1, 8): + t = Trade( + coin="BTC", + side="BUY", + price=100.0, + quantity=1.0, + value=100.0, + reason=f"dca_stage_{stage}", + timestamp=0.0, + ) + assert t.is_dca is True + + +# --------------------------------------------------------------------------- +# Serialisation +# --------------------------------------------------------------------------- + + +class TestToDict: + def test_round_trip_keys(self, entry_buy: Trade) -> None: + d = entry_buy.to_dict() + assert d["ts"] == 1700000000.0 + assert d["side"] == "buy" # lowercase in serialised form + assert d["tag"] == "entry" + assert d["symbol"] == "BTC" + assert d["qty"] == 0.01 + assert d["price"] == 42000.0 + assert d["fees_usd"] == 1.05 + assert d["order_id"] == "abc123" + assert d["pnl_pct"] is None + + def test_sell_pnl(self, exit_sell: Trade) -> None: + d = exit_sell.to_dict() + assert d["pnl_pct"] == 5.0 + assert d["side"] == "sell" + + +class TestFromDict: + def test_from_new_schema(self) -> None: + data = { + "coin": "BTC", + "side": "BUY", + "price": 42000.0, + "quantity": 0.01, + "value": 420.0, + "reason": "entry", + "timestamp": 1700000000.0, + "pnl_pct": None, + "fees_usd": 1.05, + "order_id": "abc123", + } + t = Trade.from_dict(data) + assert t.coin == "BTC" + assert t.side == "BUY" + assert t.price == 42000.0 + assert t.quantity == 0.01 + assert t.reason == "entry" + assert t.timestamp == 1700000000.0 + assert t.pnl_pct is None + assert t.fees_usd == 1.05 + + def test_from_legacy_schema(self) -> None: + """Legacy trade_history.jsonl format: symbol, ts, tag, side lowercase.""" + data = { + "symbol": "BTCUSDT", + "side": "buy", + "price": 42000.0, + "qty": 0.01, + "tag": "dca_stage_1", + "ts": 1700000000.0, + "pnl_pct": None, + } + t = Trade.from_dict(data) + assert t.coin == "BTCUSDT" + assert t.side == "BUY" + assert t.quantity == 0.01 + assert t.reason == "dca_stage_1" + assert t.timestamp == 1700000000.0 + + def test_from_dict_missing_optional(self) -> None: + data = { + "coin": "ETH", + "side": "SELL", + "price": 2000.0, + "qty": 1.0, + "value": 2000.0, + "reason": "trailing_exit", + "timestamp": 1700000000.0, + } + t = Trade.from_dict(data) + assert t.pnl_pct is None + assert t.fees_usd is None + assert t.order_id is None + + def test_roundtrip(self, entry_buy: Trade) -> None: + """to_dict → from_dict should preserve key fields.""" + d = entry_buy.to_dict() + reconstructed = Trade.from_dict(d) + assert reconstructed.coin == entry_buy.coin + assert reconstructed.side == entry_buy.side.upper() + assert reconstructed.price == entry_buy.price + assert reconstructed.quantity == entry_buy.quantity + assert reconstructed.reason == entry_buy.reason + assert reconstructed.timestamp == entry_buy.timestamp + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +class TestValidation: + def test_valid_trade(self, entry_buy: Trade) -> None: + assert entry_buy.validate() == [] + + def test_empty_coin(self) -> None: + t = Trade( + coin="", + side="BUY", + price=100.0, + quantity=1.0, + value=100.0, + reason="entry", + timestamp=0.0, + ) + errors = t.validate() + assert any("coin" in e for e in errors) + + def test_invalid_side(self) -> None: + t = Trade( + coin="BTC", + side="HOLD", + price=100.0, + quantity=1.0, + value=100.0, + reason="entry", + timestamp=0.0, + ) + errors = t.validate() + assert any("side" in e for e in errors) + + def test_negative_price(self) -> None: + t = Trade( + coin="BTC", + side="BUY", + price=-1.0, + quantity=1.0, + value=100.0, + reason="entry", + timestamp=0.0, + ) + errors = t.validate() + assert any("price" in e for e in errors) + + def test_negative_quantity(self) -> None: + t = Trade( + coin="BTC", + side="BUY", + price=100.0, + quantity=-1.0, + value=100.0, + reason="entry", + timestamp=0.0, + ) + errors = t.validate() + assert any("quantity" in e for e in errors) + + def test_negative_value(self) -> None: + t = Trade( + coin="BTC", + side="BUY", + price=100.0, + quantity=1.0, + value=-100.0, + reason="entry", + timestamp=0.0, + ) + errors = t.validate() + assert any("value" in e for e in errors) + + def test_negative_timestamp(self) -> None: + t = Trade( + coin="BTC", + side="BUY", + price=100.0, + quantity=1.0, + value=100.0, + reason="entry", + timestamp=-1.0, + ) + errors = t.validate() + assert any("timestamp" in e for e in errors) + + def test_zero_values_valid(self) -> None: + t = Trade( + coin="BTC", + side="BUY", + price=0.0, + quantity=0.0, + value=0.0, + reason="entry", + timestamp=0.0, + ) + assert t.validate() == [] diff --git a/tests/unit/models/test_types.py b/tests/unit/models/test_types.py new file mode 100644 index 000000000..3afa859e1 --- /dev/null +++ b/tests/unit/models/test_types.py @@ -0,0 +1,25 @@ +"""Tests for powertrader.models.types.""" + +from __future__ import annotations + +from powertrader.models.types import CoinSymbol, PriceLevel, SignalLevel, Timeframe + + +class TestTypeAliases: + """Type aliases are just documentation — verify they're importable and usable.""" + + def test_timeframe_is_str(self) -> None: + tf: Timeframe = "1hour" + assert isinstance(tf, str) + + def test_coin_symbol_is_str(self) -> None: + coin: CoinSymbol = "BTC" + assert isinstance(coin, str) + + def test_signal_level_is_int(self) -> None: + level: SignalLevel = 5 + assert isinstance(level, int) + + def test_price_level_is_float(self) -> None: + price: PriceLevel = 42000.50 + assert isinstance(price, float)