Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
113 changes: 113 additions & 0 deletions my-chesshacks-bot/training/scripts/local_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import os

import torch
from torch import optim
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tqdm import tqdm
from models.transformer import ChessTransformer
from stockfishdataset import StockfishDataset
from models.cnn import ChessCNN

def train_stockfish(model, test_dataloader, train_dataloader, save_path, dtype=torch.float, epochs=64):
    """Train `model` to regress Stockfish evaluations with MSE loss.

    Args:
        model: network returning a (policy, value) pair from a bitboard batch.
        test_dataloader: held-out (bitboards, values) batches for evaluation.
        train_dataloader: training (bitboards, values) batches.
        save_path: directory (relative, Windows-style path) for checkpoints.
        dtype: torch dtype used for the model and for every batch.
        epochs: number of full passes over `train_dataloader`.

    Side effects: creates the checkpoint directory, prints per-batch and
    per-epoch losses, and saves the full model after every epoch.
    """
    # Make sure the checkpoint directory exists before training starts.
    file_path = f".\\{save_path}\\Config.pt"
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    criterion = MSELoss()
    model.to("cuda", dtype)

    # NOTE(review): anomaly detection slows training substantially; this is a
    # debugging aid and should probably be disabled for real runs.
    torch.autograd.set_detect_anomaly(True)

    step = 1
    for epoch in range(epochs):
        model.train()  # ensure dropout/norm layers are in training mode
        batch_steps = 0
        epoch_total_loss = 0
        batches = len(train_dataloader)

        for batch in tqdm(train_dataloader):
            bitboards, values = batch
            values = values.to("cuda", dtype)
            bitboards = bitboards.to("cuda", dtype)

            # Model returns (policy, value); only the value head is trained here.
            predicted_policy, predicted_values = model(bitboards)
            loss = criterion(predicted_values, values)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step += 1
            epoch_total_loss += loss.item()
            batch_steps += 1

            print(f"[Step {batch_steps} / {batches}] Train: MSE Loss = {loss.item():.4f}")

        model.eval()  # disable dropout etc. for evaluation
        evaluation_total_loss = evaluate_contrastive(model, test_dataloader)

        # BUG FIX: the '\n' previously sat INSIDE the format spec (':.4f\n'),
        # which raises "Invalid format specifier" at runtime; it now lives
        # outside the braces.
        term = f"[Epoch {epoch}] Train: MSE Loss = {epoch_total_loss / batches:.4f}\n"
        term += f"Test: MSE Loss = {evaluation_total_loss / len(test_dataloader):.4f}"
        term += "\n"
        print(term)

        # Saves the whole module object (not just state_dict), matching the
        # original behavior; loading requires the class to be importable.
        torch.save(model, f".\\{save_path}\\Epoch-{epoch}.pt")


def evaluate_contrastive(model, dataloader, dtype=torch.float):
    """Accumulate the MSE value-loss of `model` over all of `dataloader`.

    Runs with gradients disabled and returns the SUM of per-batch losses
    (the caller divides by len(dataloader) to get a mean).
    """
    criterion = MSELoss()
    total_loss = 0

    with torch.no_grad():
        for bitboards, values in tqdm(dataloader):
            bitboards = bitboards.to("cuda", dtype)
            values = values.to("cuda", dtype)

            # Policy head output is ignored; only the value head is scored.
            _, predicted_values = model(bitboards)
            total_loss += criterion(predicted_values, values).item()

    return total_loss

if __name__ == "__main__":
    print("Loading...")
    # Tactic positions serve as the held-out test set.
    # NOTE(review): weights_only=False runs arbitrary pickle code on load —
    # only safe because these files are produced locally by stockfishdataset.py.
    test_bitboards = torch.load(".\\training\\data\\processed\\stockfish\\tactic_evals\\bitboards.pt", weights_only=False)
    test_evaluations = torch.load(".\\training\\data\\processed\\stockfish\\tactic_evals\\evaluations.pt", weights_only=False)
    test_set = StockfishDataset(test_bitboards, test_evaluations)

    # Randomly sampled positions form the training set.
    train_bitboards = torch.load(".\\training\\data\\processed\\stockfish\\random_evals\\bitboards.pt", weights_only=False)
    train_evaluations = torch.load(".\\training\\data\\processed\\stockfish\\random_evals\\evaluations.pt", weights_only=False)
    train_set = StockfishDataset(train_bitboards, train_evaluations)

    train_dataloader = DataLoader(
        train_set,
        batch_size=8,
        shuffle=True,
    )

    test_dataloader = DataLoader(
        test_set,
        batch_size=8,
        shuffle=True,
    )

    # Checkpoints land under .\data\ (relative to the working directory).
    save_path = ".\\data\\"

    # CNN variant is trained here; ChessTransformer import is currently unused.
    model = ChessCNN()
    #model = torch.compile(model)

    print("Training...")
    train_stockfish(model, test_dataloader, train_dataloader, save_path=save_path, dtype=torch.float, epochs=64)
48 changes: 45 additions & 3 deletions my-chesshacks-bot/training/scripts/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,55 @@ def forward(self, input):
return input + self.bias


class MultiheadAttention(nn.Module):
    """Multi-head attention backed by FlashAttention.

    Projects query/key/value with separate linear layers, runs
    `flash_attn_func` on fp16 copies (FlashAttention requires fp16/bf16),
    and returns an fp32 tensor of shape (batch, query_len, d_model).
    """

    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.embed_dim = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Kept for API parity; flash_attn_func applies 1/sqrt(head_dim) itself.
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Compute attention; query may have a different length than key/value."""
        QB, QL, _ = query.shape
        # BUG FIX: k and v previously reused the QUERY's batch/length for their
        # view, which breaks whenever key/value have a different sequence length.
        KB, KL, _ = key.shape

        q = self.q_proj(query).view(QB, QL, self.num_heads, self.head_dim)
        k = self.k_proj(key).view(KB, KL, self.num_heads, self.head_dim)
        v = self.v_proj(value).view(KB, KL, self.num_heads, self.head_dim)

        # FlashAttention only accepts half precision.
        q = q.half()
        k = k.half()
        v = v.half()

        # BUG FIX: the configured dropout probability was silently ignored
        # (hard-coded 0.0); apply it in training mode only, as nn.MultiheadAttention does.
        output = flash_attn_func(
            q, k, v,
            dropout_p=self.dropout.p if self.training else 0.0,
            causal=False,
        )

        # reshape, not view: flash output may be non-contiguous.
        output = output.reshape(QB, QL, self.embed_dim)
        output = output.float()

        # BUG FIX: out_proj was constructed but never applied; standard MHA
        # ends with the output projection.
        return self.out_proj(output)

class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
super().__init__()

self.self_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=True
)

#self.relative_bias = RelativePositionBias(nhead)

# Feedforward network
Expand All @@ -65,11 +107,11 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
self.dropout2 = nn.Dropout(dropout)

def forward(self, src):
bias = self.relative_bias(seq_len=src.size(1)) # (nhead, 64, 64)
#bias = self.relative_bias(seq_len=src.size(1)) # (nhead, 64, 64)

src2, _ = self.self_attn(
src, src, src,
attn_mask=bias.repeat(src.size(0), 1, 1) # (batch*nhead, 64, 64)
#attn_mask=bias.repeat(src.size(0), 1, 1) # (batch*nhead, 64, 64)
)

src = src + self.dropout1(src2)
Expand All @@ -88,7 +130,7 @@ def __init__(
self,
d_model=256,
nhead=8,
num_layers=4,
num_layers=6,
dim_feedforward=1024,
dropout=0.1
):
Expand Down
82 changes: 61 additions & 21 deletions my-chesshacks-bot/training/scripts/stockfishdataset.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,95 @@
import math
import os

from tqdm import tqdm
import torch
from torch.utils.data import Dataset


def square_index(rank, file):
    """Map a 1-based rank and 0-based file to a 0..63 square index."""
    return file + 8 * (rank - 1)


def fen_to_bitboards(fen):
    """Convert a FEN string into piece planes plus a side-to-move sign.

    The diffed source interleaved an old bitboard-dict version with the new
    plane version; this is the coherent plane implementation, with the
    `pieces` name no longer shadowed (it was used for both the FEN fields
    and the piece-symbol list).

    Returns:
        planes: uint8 tensor of shape (12, 8, 8). Plane order puts the side
            to move's pieces first (P,N,B,R,Q,K then p,n,b,r,q,k when white
            is to move; mirrored when black is to move).
        multiplier: +1 if white to move, -1 if black — used by callers to
            flip the evaluation to the mover's perspective.
    """
    # 12 piece planes of 8x8, all zero.
    planes = torch.zeros((12, 8, 8), dtype=torch.uint8)

    fields = fen.split(" ")
    board = fields[0]
    to_move = fields[1]

    multiplier = 1
    if to_move == "w":
        pieces = ["P", "N", "B", "R", "Q", "K",
                  "p", "n", "b", "r", "q", "k"]
    else:
        multiplier = -1
        pieces = ["p", "n", "b", "r", "q",
                  "k", "P", "N", "B", "R", "Q", "K"]

    ranks = board.split("/")

    # FEN lists rank 8 first -> row 0.
    for row, rank in enumerate(ranks):
        col = 0
        for ch in rank:
            if ch.isdigit():
                # A digit encodes a run of empty squares.
                col += int(ch)
            else:
                if ch in pieces:
                    idx = pieces.index(ch)
                    planes[idx][row][col] = 1
                col += 1

    return planes, multiplier

class StockfishDataset(Dataset):
    """Index-aligned pairs of (bitboard tensor, Stockfish evaluation).

    Args:
        bitboards: sequence of board tensors.
        evaluations: sequence of scalar evaluations, same length as bitboards.
    """

    def __init__(self, bitboards, evaluations):
        # NOTE(review): the diffed source showed a stray `return bitboards`
        # here (leftover from an earlier version); plain attribute assignment
        # is the intended behavior.
        self.bitboards = bitboards
        self.evaluations = evaluations

    def __len__(self):
        return len(self.bitboards)

    def __getitem__(self, idx):
        # Returns a (bitboard, evaluation) pair for DataLoader collation.
        return self.bitboards[idx], self.evaluations[idx]


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Path to the dataset file")
    parser.add_argument("--output_file", type=str, required=True, help="Path to the output file")

    args = parser.parse_args()

    bitboards = []
    evaluations = []

    with open(args.input_file) as f:
        # BUG FIX: the loop body previously called f.readline() while also
        # iterating the file, which skipped every other line; use `line`.
        for line in tqdm(f):
            parts = line.strip("\n").split(',')
            bitboard, to_move = fen_to_bitboards(parts[0])
            bitboard = bitboard.clone().detach()
            value = parts[1]

            # Mate scores are encoded as "#N"; keep the signed mate count.
            if value[0] == "#":
                evaluation = int(value[1:])
            else:
                evaluation = int(value)

            # Flip to the side-to-move's perspective (+1 white, -1 black).
            evaluation = evaluation * to_move

            # BUG FIX: `2^14` is bitwise XOR (value 16); the intent is to
            # clamp evaluations to +/- 2**14 centipawns.
            evaluation = max(min(evaluation, 2**14), -(2**14))

            bitboards.append(bitboard)
            evaluations.append(evaluation)

    os.makedirs(os.path.dirname(args.output_file), exist_ok=True)

    torch.save(bitboards, args.output_file + "bitboards.pt")
    torch.save(evaluations, args.output_file + "evaluations.pt")