From 834c3d745f67bc9d2d005039cb5951b20eaad7ce Mon Sep 17 00:00:00 2001 From: Erdem Arslan <69273456+Razbolt@users.noreply.github.com> Date: Thu, 5 Jun 2025 16:07:33 +0100 Subject: [PATCH 1/3] Improve data download and pair selection --- .gitignore | 22 ++++++++++++++++ README.md | 24 +++++++++++++++-- agents/__init__.py | 0 agents/ddqn_agent.py | 12 +++++++++ agents/ppo_agent.py | 12 +++++++++ agents/sac_agent.py | 12 +++++++++ data/__init__.py | 0 data/download_data.py | 41 +++++++++++++++++++++++++++++ main.py | 48 ++++++++++++++++++++++++++++++++++ requirements.txt | 8 ++++++ trading/__init__.py | 0 trading/pair_trading_env.py | 51 +++++++++++++++++++++++++++++++++++++ 12 files changed, 228 insertions(+), 2 deletions(-) create mode 100644 .gitignore create mode 100644 agents/__init__.py create mode 100644 agents/ddqn_agent.py create mode 100644 agents/ppo_agent.py create mode 100644 agents/sac_agent.py create mode 100644 data/__init__.py create mode 100644 data/download_data.py create mode 100644 main.py create mode 100644 requirements.txt create mode 100644 trading/__init__.py create mode 100644 trading/pair_trading_env.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ae5fd92 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +# Python +__pycache__/ +*.py[cod] +*.egg +*.egg-info/ +*.pyo + +# Virtual environments +venv/ +.env + +# Jupyter notebooks +.ipynb_checkpoints + +# Data +*.csv +*.h5 +*.parquet + +# OS files +.DS_Store + diff --git a/README.md b/README.md index 0778f7d..fd8cbcf 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,22 @@ -# pairtrader -Stastically Arbitrage opportunity may be exploided via RL/DL. Stay tunned for more +# Pair Trading with Reinforcement Learning + +This project demonstrates a simple research framework for creating a pair trading strategy using reinforcement learning. The repository contains utilities to download data, define a trading environment and train several RL agents. + +## Setup + +```bash +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + +## Usage + +The main entry point is `main.py`. It downloads historical data from Yahoo Finance for 20 S&P 500 tickers, prints their correlation matrix and then trains RL agents (PPO, DQN and SAC). + +```bash +python main.py --start 2015-01-01 --end 2020-01-01 +``` + +The code is for educational use. It is not production ready and should be extended for a full dissertation. + diff --git a/agents/__init__.py b/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agents/ddqn_agent.py b/agents/ddqn_agent.py new file mode 100644 index 0000000..b406ae4 --- /dev/null +++ b/agents/ddqn_agent.py @@ -0,0 +1,12 @@ +from stable_baselines3 import DQN +from stable_baselines3.common.vec_env import DummyVecEnv +from trading.pair_trading_env import PairTradingEnv +import pandas as pd + + +def train_dqn(prices: pd.DataFrame, pair: tuple[str, str], timesteps: int = 10000) -> DQN: + env = DummyVecEnv([lambda: PairTradingEnv(prices, pair)]) + model = DQN("MlpPolicy", env, verbose=0) + model.learn(total_timesteps=timesteps) + return model + diff --git a/agents/ppo_agent.py b/agents/ppo_agent.py new file mode 100644 index 0000000..196de85 --- /dev/null +++ b/agents/ppo_agent.py @@ -0,0 +1,12 @@ +from stable_baselines3 import PPO +from stable_baselines3.common.vec_env import DummyVecEnv +from trading.pair_trading_env import PairTradingEnv +import pandas as pd + + +def train_ppo(prices: pd.DataFrame, pair: tuple[str, str], timesteps: int = 10000) -> PPO: + env = DummyVecEnv([lambda: PairTradingEnv(prices, pair)]) + model = PPO("MlpPolicy", env, verbose=0) + model.learn(total_timesteps=timesteps) + return model + diff --git a/agents/sac_agent.py b/agents/sac_agent.py new file mode 100644 index 0000000..e48d3aa --- /dev/null +++ b/agents/sac_agent.py @@ -0,0 +1,12 @@ +from stable_baselines3 import SAC +from stable_baselines3.common.vec_env import DummyVecEnv +from trading.pair_trading_env import PairTradingEnv +import pandas as pd + + +def train_sac(prices: pd.DataFrame, pair: tuple[str, str], timesteps: int = 10000) -> SAC: + env = DummyVecEnv([lambda: PairTradingEnv(prices, pair)]) + model = SAC("MlpPolicy", env, verbose=0) + model.learn(total_timesteps=timesteps) + return model + diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/download_data.py b/data/download_data.py new file mode 100644 index 0000000..503aa94 --- /dev/null +++ b/data/download_data.py @@ -0,0 +1,41 @@ +import yfinance as yf +import pandas as pd +import requests +from typing import List + + +def get_sp500_tickers() -> List[str]: + """Return a list of S&P 500 tickers. + + Attempts to fetch the component list from Wikipedia. If the request fails + (e.g. due to lack of internet access), a fallback list of large cap stocks + is returned. + """ + url = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies" + try: + tables = pd.read_html(requests.get(url, timeout=10).text) + tickers = tables[0]["Symbol"].tolist() + except Exception: + # Fallback list (20 tickers) if network access is not available + tickers = [ + "AAPL", "MSFT", "GOOGL", "AMZN", "META", + "NVDA", "BRK-B", "JPM", "JNJ", "V", + "PG", "XOM", "UNH", "HD", "MA", + "CVX", "LLY", "MRK", "ABBV", "PEP", + ] + return tickers[:20] + + +def download_price_history(tickers: List[str], start: str, end: str) -> pd.DataFrame: + """Download adjusted close prices for the given tickers.""" + data = yf.download( + tickers, + start=start, + end=end, + group_by="ticker", + progress=False, + threads=False, + )["Adj Close"] + data = data.dropna(axis=0, how="all") + return data + diff --git a/main.py b/main.py new file mode 100644 index 0000000..29c2ce5 --- /dev/null +++ b/main.py @@ -0,0 +1,48 @@ +import argparse +import pandas as pd +from sklearn.preprocessing import StandardScaler +from data.download_data import get_sp500_tickers, download_price_history +from trading.pair_trading_env import PairTradingEnv +from agents.ppo_agent import train_ppo +from agents.ddqn_agent import train_dqn +from agents.sac_agent import train_sac + + +def main(start: str, end: str): + tickers = get_sp500_tickers() + prices = download_price_history(tickers, start, end) + + # Pair selection based on correlation + corr = prices.corr() + print("Correlation matrix:") + print(corr) + corr_pairs = ( + corr.stack() + .reset_index() + .rename(columns={"level_0": "stock1", "level_1": "stock2", 0: "corr"}) + ) + corr_pairs = corr_pairs[corr_pairs["stock1"] != corr_pairs["stock2"]] + corr_pairs = corr_pairs.sort_values(by="corr", ascending=False) + pair = (corr_pairs.iloc[0]["stock1"], corr_pairs.iloc[0]["stock2"]) + print(f"Selected pair: {pair}") + + # Scale spreads for environment + scaler = StandardScaler() + prices[pair] = scaler.fit_transform(prices[pair]) + + env = PairTradingEnv(prices, pair) + print("Training PPO...") + train_ppo(prices, pair, timesteps=5000) + print("Training DQN...") + train_dqn(prices, pair, timesteps=5000) + print("Training SAC...") + train_sac(prices, pair, timesteps=5000) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--start", type=str, required=True) + parser.add_argument("--end", type=str, required=True) + args = parser.parse_args() + main(args.start, args.end) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5584d39 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +pandas +numpy +matplotlib +yfinance +stable-baselines3 +gymnasium +scikit-learn +requests diff --git a/trading/__init__.py b/trading/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trading/pair_trading_env.py b/trading/pair_trading_env.py new file mode 100644 index 0000000..11af13b --- /dev/null +++ b/trading/pair_trading_env.py @@ -0,0 +1,51 @@ +import numpy as np +import pandas as pd +import gymnasium as gym +from gymnasium import spaces + + +class PairTradingEnv(gym.Env): + """A simple pair trading environment for RL.""" + + def __init__(self, prices: pd.DataFrame, pair: tuple[str, str]): + super().__init__() + self.prices = prices[list(pair)].dropna() + self.pair = pair + self.current_step = 0 + + # Observation: price spread + self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,)) + # Actions: 0 -> hold, 1 -> long first/short second, 2 -> short first/long second + self.action_space = spaces.Discrete(3) + self.position = 0 # -1 short pair, 0 flat, 1 long pair + self.entry_price = 0.0 + + def reset(self, seed=None, options=None): + super().reset(seed=seed) + self.current_step = 0 + self.position = 0 + self.entry_price = 0.0 + return self._get_obs(), {} + + def _get_obs(self): + spread = self.prices.iloc[self.current_step, 0] - self.prices.iloc[self.current_step, 1] + return np.array([spread], dtype=np.float32) + + def step(self, action: int): + done = False + reward = 0.0 + if action == 1 and self.position == 0: # long first, short second + self.position = 1 + self.entry_price = self._get_obs()[0] + elif action == 2 and self.position == 0: # short first, long second + self.position = -1 + self.entry_price = self._get_obs()[0] + elif action == 0 and self.position != 0: # close position + reward = self.position * (self.entry_price - self._get_obs()[0]) + self.position = 0 + + self.current_step += 1 + if self.current_step >= len(self.prices) - 1: + done = True + return self._get_obs(), reward, done, False, {} + From 6dc1306e041c961f0de4fe8b1630078a482cbf2b Mon Sep 17 00:00:00 2001 From: Erdem Arslan <69273456+Razbolt@users.noreply.github.com> Date: Thu, 5 Jun 2025 16:22:21 +0100 Subject: [PATCH 2/3] Resolve merge conflicts and improve docs --- README.md | 5 ++++- data/download_data.py | 20 ++++++++++++-------- main.py | 2 -- requirements.txt | 1 + 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index fd8cbcf..9ebacd1 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Pair Trading with Reinforcement Learning -This project demonstrates a simple research framework for creating a pair trading strategy using reinforcement learning. The repository contains utilities to download data, define a trading environment and train several RL agents. +This repository contains a toy framework for experimenting with pair trading using reinforcement learning. It provides helpers for fetching price data, a simple trading environment and training scripts for three different RL algorithms. The goal is to showcase the basic workflow rather than provide a production-ready solution. ## Setup @@ -20,3 +20,6 @@ python main.py --start 2015-01-01 --end 2020-01-01 The code is for educational use. It is not production ready and should be extended for a full dissertation. +Data is downloaded using `yfinance`. If internet access is unavailable the script +falls back to a small set of hard coded tickers. + diff --git a/data/download_data.py b/data/download_data.py index 503aa94..7e3dbe1 100644 --- a/data/download_data.py +++ b/data/download_data.py @@ -28,14 +28,18 @@ def get_sp500_tickers() -> List[str]: def download_price_history(tickers: List[str], start: str, end: str) -> pd.DataFrame: """Download adjusted close prices for the given tickers.""" - data = yf.download( - tickers, - start=start, - end=end, - group_by="ticker", - progress=False, - threads=False, - )["Adj Close"] + try: + data = yf.download( + tickers, + start=start, + end=end, + group_by="ticker", + progress=False, + threads=False, + )["Adj Close"] + except Exception as exc: + raise RuntimeError("Failed to download price data") from exc + data = data.dropna(axis=0, how="all") return data diff --git a/main.py b/main.py index 29c2ce5..bce3691 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,6 @@ import pandas as pd from sklearn.preprocessing import StandardScaler from data.download_data import get_sp500_tickers, download_price_history -from trading.pair_trading_env import PairTradingEnv from agents.ppo_agent import train_ppo from agents.ddqn_agent import train_dqn from agents.sac_agent import train_sac @@ -30,7 +29,6 @@ def main(start: str, end: str): scaler = StandardScaler() prices[pair] = scaler.fit_transform(prices[pair]) - env = PairTradingEnv(prices, pair) print("Training PPO...") train_ppo(prices, pair, timesteps=5000) print("Training DQN...") diff --git a/requirements.txt b/requirements.txt index 5584d39..4e0e390 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ stable-baselines3 gymnasium scikit-learn requests +torch From c979bcadd970db6c742e6b21f9787728d76986d6 Mon Sep 17 00:00:00 2001 From: Erdem Arslan <69273456+Razbolt@users.noreply.github.com> Date: Thu, 5 Jun 2025 16:22:27 +0100 Subject: [PATCH 3/3] Improve data download and pair selection --- README.md | 4 +++- data/download_data.py | 9 ++++++--- main.py | 11 ++++++++--- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 9ebacd1..d655185 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ pip install -r requirements.txt ## Usage -The main entry point is `main.py`. It downloads historical data from Yahoo Finance for 20 S&P 500 tickers, prints their correlation matrix and then trains RL agents (PPO, DQN and SAC). +The main entry point is `main.py`. It downloads daily close prices (with dividends and splits auto-adjusted) for 20 S&P 500 tickers. The prices are saved under `data/` using a file name such as `S&P-2015-01-01--2020-01-01.csv`. The script prints the correlation matrix of the prices and then trains RL agents (PPO, DQN and SAC). ```bash python main.py --start 2015-01-01 --end 2020-01-01 @@ -23,3 +23,5 @@ The code is for educational use. It is not production ready and should be extend Data is downloaded using `yfinance`. If internet access is unavailable the script falls back to a small set of hard coded tickers. +Only the adjusted close prices are used when computing the stock correlation matrix. + diff --git a/data/download_data.py b/data/download_data.py index 7e3dbe1..b0879fd 100644 --- a/data/download_data.py +++ b/data/download_data.py @@ -2,6 +2,7 @@ import pandas as pd import requests from typing import List +import io def get_sp500_tickers() -> List[str]: @@ -13,8 +14,10 @@ def get_sp500_tickers() -> List[str]: """ url = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies" try: - tables = pd.read_html(requests.get(url, timeout=10).text) + html = requests.get(url, timeout=10).text + tables = pd.read_html(io.StringIO(html), flavor="bs4") tickers = tables[0]["Symbol"].tolist() + tickers = [t.replace(".", "-") for t in tickers] except Exception: # Fallback list (20 tickers) if network access is not available tickers = [ @@ -33,10 +36,10 @@ def download_price_history(tickers: List[str], start: str, end: str) -> pd.DataF tickers, start=start, end=end, - group_by="ticker", + auto_adjust=True, progress=False, threads=False, - )["Adj Close"] + )["Close"] except Exception as exc: raise RuntimeError("Failed to download price data") from exc diff --git a/main.py b/main.py index bce3691..5ecd3c3 100644 --- a/main.py +++ b/main.py @@ -5,20 +5,25 @@ from agents.ppo_agent import train_ppo from agents.ddqn_agent import train_dqn from agents.sac_agent import train_sac +from pathlib import Path def main(start: str, end: str): tickers = get_sp500_tickers() prices = download_price_history(tickers, start, end) + data_dir = Path("data") + data_dir.mkdir(exist_ok=True) + filename = data_dir / f"S&P-{start}--{end}.csv" + prices.to_csv(filename) # Pair selection based on correlation corr = prices.corr() print("Correlation matrix:") print(corr) corr_pairs = ( - corr.stack() - .reset_index() - .rename(columns={"level_0": "stock1", "level_1": "stock2", 0: "corr"}) + corr.stack(future_stack=True) + .reset_index(name="corr") + .rename(columns={"level_0": "stock1", "level_1": "stock2"}) ) corr_pairs = corr_pairs[corr_pairs["stock1"] != corr_pairs["stock2"]] corr_pairs = corr_pairs.sort_values(by="corr", ascending=False)