Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Python
__pycache__/
*.py[cod]
*.egg
*.egg-info/
*.pyo

# Virtual environments
venv/
.env

# Jupyter notebooks
.ipynb_checkpoints

# Data
*.csv
*.h5
*.parquet

# OS files
.DS_Store

29 changes: 27 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,27 @@
# pairtrader
Statistical arbitrage opportunities may be exploited via RL/DL. Stay tuned for more.
# Pair Trading with Reinforcement Learning

This repository contains a toy framework for experimenting with pair trading using reinforcement learning. It provides helpers for fetching price data, a simple trading environment and training scripts for three different RL algorithms. The goal is to showcase the basic workflow rather than provide a production-ready solution.

## Setup

```bash
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

## Usage

The main entry point is `main.py`. It downloads daily close prices (with dividends and splits auto-adjusted) for 20 S&P 500 tickers. The prices are saved under `data/` using a file name such as `S&P-2015-01-01--2020-01-01.csv`. The script prints the correlation matrix of the prices and then trains RL agents (PPO, DQN and SAC).

```bash
python main.py --start 2015-01-01 --end 2020-01-01
```

The code is for educational use. It is not production ready and should be extended for a full dissertation.

Price data is downloaded using `yfinance`, which requires internet access. If the S&P 500
constituent list cannot be fetched, the script falls back to a small set of hard-coded tickers.

Only the adjusted close prices are used when computing the stock correlation matrix.

Empty file added agents/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions agents/ddqn_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from trading.pair_trading_env import PairTradingEnv
import pandas as pd


def train_dqn(prices: pd.DataFrame, pair: tuple[str, str], timesteps: int = 10000) -> DQN:
    """Fit a DQN agent on a ``PairTradingEnv`` built from ``prices`` and ``pair``.

    The environment is wrapped in a single-environment ``DummyVecEnv``, the
    agent learns for ``timesteps`` steps, and the trained model is returned.
    """

    def _build_env():
        # Factory handed to DummyVecEnv; it is called to create the env.
        return PairTradingEnv(prices, pair)

    vec_env = DummyVecEnv([_build_env])
    agent = DQN("MlpPolicy", vec_env, verbose=0)
    agent.learn(total_timesteps=timesteps)
    return agent

12 changes: 12 additions & 0 deletions agents/ppo_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from trading.pair_trading_env import PairTradingEnv
import pandas as pd


def train_ppo(prices: pd.DataFrame, pair: tuple[str, str], timesteps: int = 10000) -> PPO:
    """Fit a PPO agent on a ``PairTradingEnv`` built from ``prices`` and ``pair``.

    The environment is wrapped in a single-environment ``DummyVecEnv``, the
    agent learns for ``timesteps`` steps, and the trained model is returned.
    """

    def _env_factory():
        # Factory handed to DummyVecEnv; it is called to create the env.
        return PairTradingEnv(prices, pair)

    factories = [_env_factory]
    vectorized = DummyVecEnv(factories)
    ppo_model = PPO("MlpPolicy", vectorized, verbose=0)
    ppo_model.learn(total_timesteps=timesteps)
    return ppo_model

12 changes: 12 additions & 0 deletions agents/sac_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv
from trading.pair_trading_env import PairTradingEnv
import pandas as pd


def train_sac(prices: pd.DataFrame, pair: tuple[str, str], timesteps: int = 10000) -> SAC:
    """Fit a SAC agent on a ``PairTradingEnv`` built from ``prices`` and ``pair``.

    SAC only supports continuous (``Box``) action spaces, while
    ``PairTradingEnv`` exposes ``Discrete(3)``. The environment is therefore
    wrapped in an adapter that presents a one-dimensional action in
    ``[-1, 1]`` and maps it onto the three discrete actions:

    * value >  1/3  -> 1 (long first / short second)
    * value < -1/3  -> 2 (short first / long second)
    * otherwise     -> 0

    Args:
        prices: Price table containing at least the two columns in ``pair``.
        pair: The two column names to trade.
        timesteps: Number of environment steps to train for.

    Returns:
        The trained SAC model.
    """
    # Local imports keep this fix self-contained in the function.
    import gymnasium as gym
    import numpy as np
    from gymnasium import spaces

    class _ContinuousActionAdapter(gym.ActionWrapper):
        """Expose a Box action space on top of the env's discrete actions."""

        def __init__(self, env):
            super().__init__(env)
            self.action_space = spaces.Box(
                low=-1.0, high=1.0, shape=(1,), dtype=np.float32
            )

        def action(self, action):
            # Split [-1, 1] into three equal bands, one per discrete action.
            value = float(np.asarray(action).reshape(-1)[0])
            if value > 1.0 / 3.0:
                return 1
            if value < -1.0 / 3.0:
                return 2
            return 0

    def _build_env():
        return _ContinuousActionAdapter(PairTradingEnv(prices, pair))

    env = DummyVecEnv([_build_env])
    model = SAC("MlpPolicy", env, verbose=0)
    model.learn(total_timesteps=timesteps)
    return model

Empty file added data/__init__.py
Empty file.
48 changes: 48 additions & 0 deletions data/download_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import yfinance as yf
import pandas as pd
import requests
from typing import List
import io


def get_sp500_tickers(limit: int = 20) -> List[str]:
    """Return up to ``limit`` S&P 500 tickers.

    Attempts to fetch the component list from Wikipedia. If anything in that
    process fails (no network, HTTP error status, missing HTML parser, changed
    page layout), a warning is logged and a fallback list of 20 large cap
    stocks is used instead.

    Args:
        limit: Maximum number of tickers to return. Defaults to 20, matching
            the previous hard-coded behaviour.

    Returns:
        The first ``limit`` tickers, with ``.`` replaced by ``-`` in the
        symbols read from Wikipedia.

    Raises:
        ValueError: If ``limit`` is negative.
    """
    import logging

    if limit < 0:
        raise ValueError("limit must be non-negative")

    url = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
    try:
        response = requests.get(url, timeout=10)
        # Without this an HTTP error page would be handed to the HTML parser.
        response.raise_for_status()
        # NOTE(review): flavor="bs4" needs beautifulsoup4 and html5lib
        # installed; confirm they are available, otherwise this always falls
        # through to the fallback list.
        tables = pd.read_html(io.StringIO(response.text), flavor="bs4")
        tickers = [str(symbol).strip().replace(".", "-") for symbol in tables[0]["Symbol"]]
        if not tickers:
            raise ValueError("no symbols found in the constituents table")
    except Exception as exc:
        # Deliberate best-effort: any failure falls back to a static list,
        # but the reason is logged rather than swallowed silently.
        logging.getLogger(__name__).warning(
            "Could not fetch S&P 500 constituents (%r); using fallback tickers", exc
        )
        # Fallback list (20 tickers) if network access is not available
        tickers = [
            "AAPL", "MSFT", "GOOGL", "AMZN", "META",
            "NVDA", "BRK-B", "JPM", "JNJ", "V",
            "PG", "XOM", "UNH", "HD", "MA",
            "CVX", "LLY", "MRK", "ABBV", "PEP",
        ]
    return tickers[:limit]


def download_price_history(tickers: List[str], start: str, end: str) -> pd.DataFrame:
    """Download adjusted close prices for the given tickers.

    Prices are requested from ``yfinance`` with ``auto_adjust=True`` and only
    the ``"Close"`` field is kept. Rows where every ticker is missing are
    dropped.

    Args:
        tickers: Symbols to download.
        start: Start date passed through to ``yf.download``.
        end: End date passed through to ``yf.download``.

    Returns:
        A DataFrame of close prices with one column per ticker.

    Raises:
        RuntimeError: If the download raises, or if no usable price rows are
            returned for the requested tickers and date range.
    """
    try:
        data = yf.download(
            tickers,
            start=start,
            end=end,
            auto_adjust=True,
            progress=False,
            threads=False,
        )["Close"]
    except Exception as exc:
        raise RuntimeError("Failed to download price data") from exc

    # A single-ticker request may come back as a Series; honour the declared
    # DataFrame return type so callers can always index by column.
    if isinstance(data, pd.Series):
        column = tickers if isinstance(tickers, str) else tickers[0]
        data = data.to_frame(name=column)

    data = data.dropna(axis=0, how="all")
    if data.empty:
        # Fail here with a clear message instead of letting an empty frame
        # propagate into the caller.
        raise RuntimeError(
            f"No price data returned for {tickers} between {start} and {end}"
        )
    return data

51 changes: 51 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import argparse
import pandas as pd
from sklearn.preprocessing import StandardScaler
from data.download_data import get_sp500_tickers, download_price_history
from agents.ppo_agent import train_ppo
from agents.ddqn_agent import train_dqn
from agents.sac_agent import train_sac
from pathlib import Path


def _most_correlated_pair(corr: pd.DataFrame) -> tuple[str, str]:
    """Return the two distinct tickers with the highest pairwise correlation."""
    # Name both axes explicitly so the stacked columns are always
    # "stock1"/"stock2", whether or not the input axes already carry a name.
    corr_pairs = (
        corr.rename_axis(index="stock1", columns="stock2")
        .stack(future_stack=True)
        .reset_index(name="corr")
    )
    corr_pairs = corr_pairs[corr_pairs["stock1"] != corr_pairs["stock2"]]
    # NaN correlations (e.g. no overlapping history) cannot be ranked.
    corr_pairs = corr_pairs.dropna(subset=["corr"])
    if corr_pairs.empty:
        raise ValueError(
            "Need at least two tickers with overlapping price history to select a pair"
        )
    best = corr_pairs.sort_values(by="corr", ascending=False).iloc[0]
    return (str(best["stock1"]), str(best["stock2"]))


def main(start: str, end: str):
    """Download prices, pick the most correlated pair and train the RL agents.

    Steps:
        1. Fetch the ticker list and their price history for ``start``..``end``.
        2. Save the prices to ``data/S&P-<start>--<end>.csv``.
        3. Print the correlation matrix and select the most correlated pair.
        4. Standardize the two selected price columns.
        5. Train PPO, DQN and SAC agents on the selected pair.

    Args:
        start: Start date of the price history.
        end: End date of the price history.

    Raises:
        ValueError: If no valid pair can be selected from the correlations.
    """
    tickers = get_sp500_tickers()
    prices = download_price_history(tickers, start, end)
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)
    filename = data_dir / f"S&P-{start}--{end}.csv"
    prices.to_csv(filename)

    # Pair selection based on correlation
    corr = prices.corr()
    print("Correlation matrix:")
    print(corr)
    pair = _most_correlated_pair(corr)
    print(f"Selected pair: {pair}")

    # Standardize each leg's prices for the environment. The columns must be
    # selected with a list: indexing with the tuple itself is treated as a
    # single column key and raises KeyError on flat columns.
    legs = list(pair)
    scaler = StandardScaler()
    prices[legs] = scaler.fit_transform(prices[legs])

    print("Training PPO...")
    train_ppo(prices, pair, timesteps=5000)
    print("Training DQN...")
    train_dqn(prices, pair, timesteps=5000)
    print("Training SAC...")
    train_sac(prices, pair, timesteps=5000)


if __name__ == "__main__":
    # Command-line entry point: both date bounds are mandatory string options.
    cli = argparse.ArgumentParser()
    for flag in ("--start", "--end"):
        cli.add_argument(flag, type=str, required=True)
    options = cli.parse_args()
    main(options.start, options.end)

9 changes: 9 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pandas
numpy
matplotlib
yfinance
stable-baselines3
gymnasium
scikit-learn
requests
torch
Empty file added trading/__init__.py
Empty file.
51 changes: 51 additions & 0 deletions trading/pair_trading_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces


class PairTradingEnv(gym.Env):
    """A simple pair trading environment for RL.

    Observation:
        A one-element float32 array holding the spread, i.e. the first leg's
        price minus the second leg's price at the current step.

    Actions (``Discrete(3)``):
        * 0 -- stay flat when no position is open; CLOSE the position when
          one is open (this is the only step that yields a reward).
        * 1 -- open a long-spread position (long first / short second) when
          flat; ignored while a position is open.
        * 2 -- open a short-spread position (short first / long second) when
          flat; ignored while a position is open.

    Reward:
        Zero on every step except when a position is closed, where it is the
        realized spread P&L: ``position * (current_spread - entry_spread)``.
        A position still open when the episode ends is not settled.

    The episode ends when the step index reaches the last available row.
    """

    def __init__(self, prices: pd.DataFrame, pair: tuple[str, str]):
        """Build the environment from the two ``pair`` columns of ``prices``.

        Rows where either leg is missing are dropped.

        Raises:
            ValueError: If fewer than two usable rows remain, since at least
                one transition is needed to run an episode.
        """
        super().__init__()
        self.prices = prices[list(pair)].dropna()
        if len(self.prices) < 2:
            raise ValueError(
                "PairTradingEnv needs at least two rows of prices for both legs"
            )
        self.pair = pair
        self.current_step = 0

        # Observation: price spread
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32
        )
        # Actions: 0 -> flat/close, 1 -> long first/short second,
        # 2 -> short first/long second
        self.action_space = spaces.Discrete(3)
        self.position = 0  # -1 short pair, 0 flat, 1 long pair
        self.entry_price = 0.0  # spread at which the open position was entered

    def reset(self, seed=None, options=None):
        """Rewind to the first row with no open position.

        Returns:
            The initial observation and an empty info dict.
        """
        super().reset(seed=seed)
        self.current_step = 0
        self.position = 0
        self.entry_price = 0.0
        return self._get_obs(), {}

    def _get_obs(self):
        """Return the current spread (first leg minus second) as float32."""
        row = self.prices.iloc[self.current_step]
        spread = row.iloc[0] - row.iloc[1]
        return np.array([spread], dtype=np.float32)

    def step(self, action: int):
        """Apply ``action`` at the current row, then advance one row.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``truncated`` is always ``False`` and ``info`` is empty.
        """
        reward = 0.0
        spread = float(self._get_obs()[0])

        if self.position == 0:
            if action == 1:  # long first, short second
                self.position = 1
                self.entry_price = spread
            elif action == 2:  # short first, long second
                self.position = -1
                self.entry_price = spread
        elif action == 0:  # close position
            # A long-spread position gains when the spread rises, a short one
            # when it falls, so P&L is position * (current - entry).
            reward = float(self.position * (spread - self.entry_price))
            self.position = 0
            self.entry_price = 0.0

        self.current_step += 1
        done = self.current_step >= len(self.prices) - 1
        return self._get_obs(), reward, done, False, {}