Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions plan.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,13 +324,13 @@ PriceLevel: TypeAlias = float
```

**Phase 2 Deliverables:**
- [ ] `models/candle.py`
- [ ] `models/signal.py`
- [ ] `models/position.py`
- [ ] `models/trade.py`
- [ ] `models/memory.py`
- [ ] `models/types.py`
- [ ] Unit tests for all model validation and properties
- [x] `models/candle.py`
- [x] `models/signal.py`
- [x] `models/position.py`
- [x] `models/trade.py`
- [x] `models/memory.py`
- [x] `models/types.py`
- [x] Unit tests for all model validation and properties (138 tests)

---

Expand Down
25 changes: 25 additions & 0 deletions src/powertrader/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Domain data models for PowerTrader AI.

Re-exports all model classes for convenient imports::

from powertrader.models import Candle, Signal, Position, Trade, PatternMemory
"""

from powertrader.models.candle import Candle
from powertrader.models.memory import PatternMemory
from powertrader.models.position import Position
from powertrader.models.signal import Signal
from powertrader.models.trade import Trade
from powertrader.models.types import CoinSymbol, PriceLevel, SignalLevel, Timeframe

__all__ = [
"Candle",
"CoinSymbol",
"PatternMemory",
"Position",
"PriceLevel",
"Signal",
"SignalLevel",
"Timeframe",
"Trade",
]
119 changes: 119 additions & 0 deletions src/powertrader/models/candle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""OHLCV candle data model.

Represents a single candlestick bar as returned by market data APIs
(KuCoin klines). Immutable so it can be safely shared and cached.
"""

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True, slots=True)
class Candle:
"""A single OHLCV candlestick bar.

Parameters
----------
timestamp:
Candle open time as a Unix epoch in **seconds**.
open, high, low, close:
Price values for the bar.
volume:
Traded volume in the base asset during this bar.
"""

timestamp: int
open: float
high: float
low: float
close: float
volume: float

# -- derived properties ---------------------------------------------------

@property
def body_pct(self) -> float:
"""Percentage change from open to close: ``(close - open) / open * 100``.

Returns ``0.0`` if *open* is zero (degenerate candle).
"""
if self.open == 0.0:
return 0.0
return (self.close - self.open) / self.open * 100.0

@property
def range_pct(self) -> float:
"""Total bar range as a percentage of the low: ``(high - low) / low * 100``.

Returns ``0.0`` if *low* is zero.
"""
if self.low == 0.0:
return 0.0
return (self.high - self.low) / self.low * 100.0

@property
def upper_shadow_pct(self) -> float:
"""Upper shadow as a percentage of *open*: ``(high - max(open, close)) / open * 100``.

Returns ``0.0`` if *open* is zero.
"""
if self.open == 0.0:
return 0.0
top = max(self.open, self.close)
return (self.high - top) / self.open * 100.0

@property
def lower_shadow_pct(self) -> float:
"""Lower shadow as a percentage of *open*: ``(min(open, close) - low) / open * 100``.

Returns ``0.0`` if *open* is zero.
"""
if self.open == 0.0:
return 0.0
bottom = min(self.open, self.close)
return (bottom - self.low) / self.open * 100.0

@property
def is_bullish(self) -> bool:
"""``True`` if the close is strictly above the open."""
return self.close > self.open

@property
def is_bearish(self) -> bool:
"""``True`` if the close is strictly below the open."""
return self.close < self.open

@property
def mid(self) -> float:
"""Midpoint price: ``(high + low) / 2``."""
return (self.high + self.low) / 2.0

# -- validation -----------------------------------------------------------

def validate(self) -> list[str]:
"""Return a list of validation errors (empty means valid)."""
errors: list[str] = []
if self.timestamp < 0:
errors.append(f"timestamp={self.timestamp} must be >= 0.")
if self.open < 0:
errors.append(f"open={self.open} must be >= 0.")
if self.high < 0:
errors.append(f"high={self.high} must be >= 0.")
if self.low < 0:
errors.append(f"low={self.low} must be >= 0.")
if self.close < 0:
errors.append(f"close={self.close} must be >= 0.")
if self.volume < 0:
errors.append(f"volume={self.volume} must be >= 0.")
if self.high < self.low:
errors.append(f"high={self.high} must be >= low={self.low}.")
if self.high < self.open:
errors.append(f"high={self.high} must be >= open={self.open}.")
if self.high < self.close:
errors.append(f"high={self.high} must be >= close={self.close}.")
if self.low > self.open:
errors.append(f"low={self.low} must be <= open={self.open}.")
if self.low > self.close:
errors.append(f"low={self.low} must be <= close={self.close}.")
return errors
200 changes: 200 additions & 0 deletions src/powertrader/models/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""Pattern memory data model.

A :class:`PatternMemory` holds the trained price-pattern data for a
single coin on a single timeframe. The trainer builds these from
historical kline data, and the thinker reads them to generate signals.

File format
-----------
``memories_<tf>.txt`` uses a custom delimited text format:

- Patterns are separated by ``~``
- Each pattern has three fields separated by ``{}``::

candle_pcts{}high_diff{}low_diff

where *candle_pcts* is a space-separated sequence of percentage
changes, *high_diff* is the predicted-high deviation, and *low_diff*
is the predicted-low deviation.

Parallel weight files (``memory_weights_<tf>.txt``, etc.) contain
space-separated floats — one weight per pattern.
"""

from __future__ import annotations

from dataclasses import dataclass, field

# Delimiters used in the on-disk memory format.
PATTERN_SEPARATOR: str = "~"
FIELD_SEPARATOR: str = "{}"


@dataclass(slots=True)
class PatternMemory:
"""Trained pattern memory for one coin / one timeframe.

Parameters
----------
patterns:
List of patterns. Each pattern is a list of floats representing
the candle-body percentage changes that define the shape.
high_diffs:
Predicted-high deviation for each pattern (parallel to *patterns*).
low_diffs:
Predicted-low deviation for each pattern (parallel to *patterns*).
weights:
Reliability weight for the base prediction of each pattern.
weights_high:
Reliability weight for the high prediction of each pattern.
weights_low:
Reliability weight for the low prediction of each pattern.
threshold:
Maximum distance for a current candle sequence to "match" a
stored pattern.
"""

patterns: list[list[float]] = field(default_factory=list)
high_diffs: list[float] = field(default_factory=list)
low_diffs: list[float] = field(default_factory=list)
weights: list[float] = field(default_factory=list)
weights_high: list[float] = field(default_factory=list)
weights_low: list[float] = field(default_factory=list)
threshold: float = 1.0

# -- derived properties ---------------------------------------------------

@property
def size(self) -> int:
"""Number of stored patterns."""
return len(self.patterns)

@property
def is_empty(self) -> bool:
"""``True`` if no patterns have been stored."""
return len(self.patterns) == 0

# -- serialisation (on-disk format) ---------------------------------------

def to_memory_text(self) -> str:
"""Serialise patterns to the ``memories_<tf>.txt`` format.

Each pattern is rendered as::

candle_pct1 candle_pct2{}high_diff{}low_diff

Patterns are joined by ``~``.
"""
parts: list[str] = []
for i, pat in enumerate(self.patterns):
candle_str = " ".join(str(v) for v in pat)
high = self.high_diffs[i] if i < len(self.high_diffs) else 0.0
low = self.low_diffs[i] if i < len(self.low_diffs) else 0.0
parts.append(f"{candle_str}{FIELD_SEPARATOR}{high}{FIELD_SEPARATOR}{low}")
return PATTERN_SEPARATOR.join(parts)

@classmethod
def from_memory_text(
cls,
text: str,
weights_text: str = "",
weights_high_text: str = "",
weights_low_text: str = "",
threshold: float = 1.0,
) -> PatternMemory:
"""Parse from the on-disk ``memories_<tf>.txt`` format.

Parameters
----------
text:
Contents of ``memories_<tf>.txt``.
weights_text:
Contents of ``memory_weights_<tf>.txt`` (space-separated floats).
weights_high_text:
Contents of ``memory_weights_high_<tf>.txt``.
weights_low_text:
Contents of ``memory_weights_low_<tf>.txt``.
threshold:
Value from ``neural_perfect_threshold_<tf>.txt``.
"""
patterns: list[list[float]] = []
high_diffs: list[float] = []
low_diffs: list[float] = []

raw_patterns = text.strip().split(PATTERN_SEPARATOR) if text.strip() else []
for raw in raw_patterns:
raw = raw.strip()
if not raw:
continue
fields = raw.split(FIELD_SEPARATOR)
# Field 0: candle percentages (space-separated)
candle_pcts = _parse_floats_space(fields[0]) if fields else []
if not candle_pcts:
continue
patterns.append(candle_pcts)
# Field 1: high_diff
high_diffs.append(_safe_float(fields[1]) if len(fields) > 1 else 0.0)
# Field 2: low_diff
low_diffs.append(_safe_float(fields[2]) if len(fields) > 2 else 0.0)

return cls(
patterns=patterns,
high_diffs=high_diffs,
low_diffs=low_diffs,
weights=_parse_floats_space(weights_text),
weights_high=_parse_floats_space(weights_high_text),
weights_low=_parse_floats_space(weights_low_text),
threshold=threshold,
)

# -- validation -----------------------------------------------------------

def validate(self) -> list[str]:
"""Return a list of validation errors (empty means valid)."""
errors: list[str] = []
n = len(self.patterns)
if len(self.high_diffs) != n:
errors.append(f"high_diffs length ({len(self.high_diffs)}) != patterns length ({n}).")
if len(self.low_diffs) != n:
errors.append(f"low_diffs length ({len(self.low_diffs)}) != patterns length ({n}).")
if self.weights and len(self.weights) != n:
errors.append(f"weights length ({len(self.weights)}) != patterns length ({n}).")
if self.weights_high and len(self.weights_high) != n:
errors.append(
f"weights_high length ({len(self.weights_high)}) != patterns length ({n})."
)
if self.weights_low and len(self.weights_low) != n:
errors.append(
f"weights_low length ({len(self.weights_low)}) != patterns length ({n})."
)
if self.threshold < 0:
errors.append(f"threshold={self.threshold} must be >= 0.")
return errors


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _parse_floats_space(text: str) -> list[float]:
"""Parse space-separated floats, skipping blanks."""
if not text or not text.strip():
return []
result: list[float] = []
for tok in text.strip().split():
tok = tok.strip()
if tok:
try:
result.append(float(tok))
except ValueError:
continue
return result


def _safe_float(text: str) -> float:
"""Parse a single float, returning ``0.0`` on failure."""
try:
return float(text.strip())
except (ValueError, AttributeError):
return 0.0
Loading