diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ae5fd92 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +# Python +__pycache__/ +*.py[cod] +*.egg +*.egg-info/ +*.pyo + +# Virtual environments +venv/ +.env + +# Jupyter notebooks +.ipynb_checkpoints + +# Data +*.csv +*.h5 +*.parquet + +# OS files +.DS_Store + diff --git a/README.md b/README.md index 0778f7d..d655185 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,27 @@ -# pairtrader -Stastically Arbitrage opportunity may be exploided via RL/DL. Stay tunned for more +# Pair Trading with Reinforcement Learning + +This repository contains a toy framework for experimenting with pair trading using reinforcement learning. It provides helpers for fetching price data, a simple trading environment and training scripts for three different RL algorithms. The goal is to showcase the basic workflow rather than provide a production-ready solution. + +## Setup + +```bash +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + +## Usage + +The main entry point is `main.py`. It downloads daily close prices (with dividends and splits auto-adjusted) for 20 S&P 500 tickers. The prices are saved under `data/` using a file name such as `S&P-2015-01-01--2020-01-01.csv`. The script prints the correlation matrix of the prices and then trains RL agents (PPO, DQN and SAC). + +```bash +python main.py --start 2015-01-01 --end 2020-01-01 +``` + +The code is for educational use. It is not production ready and should be extended for a full dissertation. + +Data is downloaded using `yfinance`. If internet access is unavailable the script +falls back to a small set of hard coded tickers. + +Only the adjusted close prices are used when computing the stock correlation matrix. 
# ===========================================================================
# agents/__init__.py  (empty package marker)
# ===========================================================================

# ===========================================================================
# agents/ddqn_agent.py
# ===========================================================================
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from trading.pair_trading_env import PairTradingEnv
import pandas as pd


def train_dqn(prices: pd.DataFrame, pair: tuple[str, str], timesteps: int = 10000) -> DQN:
    """Train a DQN agent on the pair-trading environment.

    Args:
        prices: Price table; must contain one column per ticker in ``pair``.
        pair: The two column names that form the traded pair.
        timesteps: Total number of environment steps to train for.

    Returns:
        The trained ``DQN`` model.
    """
    # SB3 algorithms expect a vectorised env; a single-env DummyVecEnv is enough.
    env = DummyVecEnv([lambda: PairTradingEnv(prices, pair)])
    model = DQN("MlpPolicy", env, verbose=0)
    model.learn(total_timesteps=timesteps)
    return model


# ===========================================================================
# agents/ppo_agent.py
# ===========================================================================
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from trading.pair_trading_env import PairTradingEnv
import pandas as pd


def train_ppo(prices: pd.DataFrame, pair: tuple[str, str], timesteps: int = 10000) -> PPO:
    """Train a PPO agent on the pair-trading environment.

    Args:
        prices: Price table; must contain one column per ticker in ``pair``.
        pair: The two column names that form the traded pair.
        timesteps: Total number of environment steps to train for.

    Returns:
        The trained ``PPO`` model.
    """
    env = DummyVecEnv([lambda: PairTradingEnv(prices, pair)])
    model = PPO("MlpPolicy", env, verbose=0)
    model.learn(total_timesteps=timesteps)
    return model


# ===========================================================================
# agents/sac_agent.py
# ===========================================================================
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv
from trading.pair_trading_env import PairTradingEnv
import pandas as pd


class _ContinuousActionAdapter(gym.ActionWrapper):
    """Expose the 3-action pair-trading env through a 1-D continuous action.

    SAC only supports continuous (``Box``) action spaces, while
    ``PairTradingEnv`` declares ``Discrete(3)``.  The adapter advertises a
    ``Box(-1, 1, (1,))`` space and bins the value into the env's actions:

    * value >  1/3  -> 1 (long first / short second)
    * value < -1/3  -> 2 (short first / long second)
    * otherwise     -> 0 (hold / close)
    """

    _THRESHOLD = 1.0 / 3.0

    def __init__(self, env: gym.Env):
        super().__init__(env)
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

    def action(self, action) -> int:
        """Map the continuous policy output to a discrete env action."""
        value = float(np.asarray(action).reshape(-1)[0])
        if value > self._THRESHOLD:
            return 1
        if value < -self._THRESHOLD:
            return 2
        return 0


def train_sac(prices: pd.DataFrame, pair: tuple[str, str], timesteps: int = 10000) -> SAC:
    """Train a SAC agent on the pair-trading environment.

    The environment is wrapped in ``_ContinuousActionAdapter`` because SAC
    rejects discrete action spaces at construction time.

    Args:
        prices: Price table; must contain one column per ticker in ``pair``.
        pair: The two column names that form the traded pair.
        timesteps: Total number of environment steps to train for.

    Returns:
        The trained ``SAC`` model.
    """
    env = DummyVecEnv(
        [lambda: _ContinuousActionAdapter(PairTradingEnv(prices, pair))]
    )
    model = SAC("MlpPolicy", env, verbose=0)
    model.learn(total_timesteps=timesteps)
    return model


# ===========================================================================
# data/__init__.py  (empty package marker)
# ===========================================================================
# ===========================================================================
# data/download_data.py
# ===========================================================================
import io
import logging
from typing import List, Optional

import pandas as pd
import requests
import yfinance as yf

logger = logging.getLogger(__name__)


def get_sp500_tickers(limit: Optional[int] = 20) -> List[str]:
    """Return the first ``limit`` S&P 500 tickers.

    Attempts to fetch the component list from Wikipedia. If anything in that
    path fails (no network, HTTP error, missing HTML parser, changed page
    layout), a hard-coded fallback list of 20 large-cap tickers is used and a
    warning is logged.

    Args:
        limit: Maximum number of tickers to return. ``None`` returns every
            ticker that was found. Defaults to 20.

    Returns:
        Ticker symbols with ``.`` replaced by ``-`` (e.g. ``BRK.B`` ->
        ``BRK-B``).
    """
    url = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
    try:
        response = requests.get(url, timeout=10)
        # Without this an error page would be handed to the HTML parser.
        response.raise_for_status()
        # NOTE(review): flavor="bs4" needs beautifulsoup4 and html5lib, which
        # requirements.txt does not list - confirm they are installed,
        # otherwise this branch always falls back.
        tables = pd.read_html(io.StringIO(response.text), flavor="bs4")
        tickers = tables[0]["Symbol"].tolist()
        tickers = [t.replace(".", "-") for t in tickers]
    except Exception as exc:  # deliberate best-effort: any failure -> fallback
        logger.warning("Could not fetch S&P 500 constituents, using fallback list: %s", exc)
        # Fallback list (20 tickers) if network access is not available
        tickers = [
            "AAPL", "MSFT", "GOOGL", "AMZN", "META",
            "NVDA", "BRK-B", "JPM", "JNJ", "V",
            "PG", "XOM", "UNH", "HD", "MA",
            "CVX", "LLY", "MRK", "ABBV", "PEP",
        ]
    return tickers if limit is None else tickers[:limit]


def download_price_history(tickers: List[str], start: str, end: str) -> pd.DataFrame:
    """Download adjusted close prices for the given tickers.

    Args:
        tickers: Ticker symbols to download.
        start: Start date passed through to ``yfinance.download``.
        end: End date passed through to ``yfinance.download``.

    Returns:
        A DataFrame of ``Close`` prices (``auto_adjust=True``), one column
        per ticker, with rows that are entirely NaN removed.

    Raises:
        RuntimeError: If the download fails or yields no usable rows.
    """
    try:
        data = yf.download(
            tickers,
            start=start,
            end=end,
            auto_adjust=True,
            progress=False,
            threads=False,
        )["Close"]
    except Exception as exc:
        raise RuntimeError("Failed to download price data") from exc

    # A single ticker may come back as a Series; callers expect a DataFrame.
    if isinstance(data, pd.Series):
        data = data.to_frame(name=tickers if isinstance(tickers, str) else tickers[0])

    data = data.dropna(axis=0, how="all")
    if data.empty:
        # Fail here with a clear message instead of crashing later in
        # correlation / pair selection on an empty table.
        raise RuntimeError("Failed to download price data: no rows returned")
    return data


# ===========================================================================
# main.py
# ===========================================================================
import argparse
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from data.download_data import get_sp500_tickers, download_price_history
from agents.ppo_agent import train_ppo
from agents.ddqn_agent import train_dqn
from agents.sac_agent import train_sac
from pathlib import Path


def _select_most_correlated_pair(corr: pd.DataFrame) -> tuple[str, str]:
    """Return the labels of the two distinct series with the highest correlation.

    Only the strict upper triangle is searched, which skips the diagonal
    (self-correlation) and the mirrored duplicates. Working on the raw array
    avoids depending on index/column level names.

    Raises:
        ValueError: If there is no finite off-diagonal correlation.
    """
    values = corr.to_numpy(dtype=float, copy=True)
    values[np.tril_indices_from(values)] = np.nan
    if np.isnan(values).all():
        raise ValueError("Need at least two tickers with overlapping prices to select a pair")
    row, col = np.unravel_index(np.nanargmax(values), values.shape)
    return str(corr.index[row]), str(corr.columns[col])


def main(start: str, end: str):
    """Download prices, pick the most correlated pair and train PPO, DQN and SAC.

    Args:
        start: Start date of the price history (e.g. ``2015-01-01``).
        end: End date of the price history.
    """
    tickers = get_sp500_tickers()
    prices = download_price_history(tickers, start, end)
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)
    filename = data_dir / f"S&P-{start}--{end}.csv"
    prices.to_csv(filename)

    # Pair selection based on correlation
    corr = prices.corr()
    print("Correlation matrix:")
    print(corr)
    pair = _select_most_correlated_pair(corr)
    print(f"Selected pair: {pair}")

    # Standardise both price series so the environment's spread is on a
    # comparable scale. Index with a list: a tuple would be looked up as a
    # single column key and raise KeyError.
    columns = list(pair)
    scaler = StandardScaler()
    prices[columns] = scaler.fit_transform(prices[columns])

    print("Training PPO...")
    train_ppo(prices, pair, timesteps=5000)
    print("Training DQN...")
    train_dqn(prices, pair, timesteps=5000)
    print("Training SAC...")
    train_sac(prices, pair, timesteps=5000)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--start", type=str, required=True)
    parser.add_argument("--end", type=str, required=True)
    args = parser.parse_args()
    main(args.start, args.end)


# ===========================================================================
# requirements.txt (contents unchanged)
#   pandas
#   numpy
#   matplotlib
#   yfinance
#   stable-baselines3
#   gymnasium
#   scikit-learn
#   requests
#   torch
# ===========================================================================

# ===========================================================================
# trading/__init__.py  (empty package marker)
# ===========================================================================

# ===========================================================================
# trading/pair_trading_env.py
# ===========================================================================
import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces


class PairTradingEnv(gym.Env):
    """A simple pair trading environment for RL.

    The observation is the current spread (first price minus second price).
    Actions:

    * 0 - hold while flat, or close the open position
    * 1 - open "long first / short second" (only when flat)
    * 2 - open "short first / long second" (only when flat)

    A reward is paid only when a position is closed and equals the spread
    change captured by that position. Actions 1 and 2 are ignored while a
    position is open.
    """

    def __init__(self, prices: pd.DataFrame, pair: tuple[str, str]):
        """Build the environment from the two ``pair`` columns of ``prices``.

        Raises:
            ValueError: If fewer than two rows remain after dropping NaNs.
        """
        super().__init__()
        self.prices = prices[list(pair)].dropna()
        if len(self.prices) < 2:
            # reset() reads row 0 and the first step() reads row 1.
            raise ValueError("PairTradingEnv needs at least two rows of non-NaN prices")
        self.pair = pair
        self.current_step = 0

        # Observation: price spread
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32
        )
        # Actions: 0 -> hold, 1 -> long first/short second, 2 -> short first/long second
        self.action_space = spaces.Discrete(3)
        self.position = 0  # -1 short pair, 0 flat, 1 long pair
        self.entry_price = 0.0

    def reset(self, seed=None, options=None):
        """Rewind to the first row, flatten the position and return ``(obs, info)``."""
        super().reset(seed=seed)
        self.current_step = 0
        self.position = 0
        self.entry_price = 0.0
        return self._get_obs(), {}

    def _get_obs(self):
        """Return the spread at ``current_step`` as a float32 array of shape (1,)."""
        row = self.prices.iloc[self.current_step]
        spread = row.iloc[0] - row.iloc[1]
        return np.array([spread], dtype=np.float32)

    def step(self, action: int):
        """Apply ``action`` at the current row, then advance one row.

        Returns:
            ``(obs, reward, terminated, truncated, info)``; ``terminated`` is
            set once the last row has been reached and ``truncated`` is
            always ``False``.
        """
        spread = float(self._get_obs()[0])
        reward = 0.0
        if self.position == 0:
            if action == 1:  # long first, short second
                self.position = 1
                self.entry_price = spread
            elif action == 2:  # short first, long second
                self.position = -1
                self.entry_price = spread
        elif action == 0:  # close position
            # Long the spread (+1) profits when the spread rises, short (-1)
            # when it falls, so PnL is position * (current - entry).
            reward = self.position * (spread - self.entry_price)
            self.position = 0

        self.current_step += 1
        done = self.current_step >= len(self.prices) - 1
        return self._get_obs(), reward, done, False, {}