import os

import torch
from torch import optim
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.transformer import ChessTransformer
from models.cnn import ChessCNN
from stockfishdataset import StockfishDataset


def train_stockfish(model, test_dataloader, train_dataloader, save_path,
                    dtype=torch.float, epochs=64):
    """Train `model` to regress Stockfish evaluations with an MSE loss.

    Parameters
    ----------
    model : module returning a (policy, value) pair for a bitboard batch;
        only the value head is trained here.
    test_dataloader, train_dataloader : yield (bitboards, values) batches.
    save_path : directory (relative) where per-epoch checkpoints are written.
    dtype : dtype the model and batches are cast to (default torch.float).
    epochs : number of passes over the training set.
    """
    # Ensure the checkpoint directory exists before the first torch.save.
    file_path = f".\\{save_path}\\Config.pt"
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    criterion = MSELoss()
    model.to("cuda", dtype)

    # NOTE(review): anomaly detection slows training considerably; disable
    # once the NaN/backward issues it was debugging are resolved.
    torch.autograd.set_detect_anomaly(True)

    step = 1
    for epoch in range(epochs):
        # BUG FIX: re-enable training mode each epoch — evaluate_contrastive
        # switches the model to eval mode at the end of the previous epoch.
        model.train()

        batch_steps = 0
        epoch_total_loss = 0
        batches = len(train_dataloader)

        for bitboards, values in tqdm(train_dataloader):
            values = values.to("cuda", dtype)
            bitboards = bitboards.to("cuda", dtype)

            # Policy head output is intentionally unused in this phase.
            predicted_policy, predicted_values = model(bitboards)
            loss = criterion(predicted_values, values)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step += 1
            epoch_total_loss += loss.item()
            batch_steps += 1

            print(f"[Step {batch_steps} / {batches}] Train: MSE Loss = {loss.item():.4f}")

        evaluation_total_loss = evaluate_contrastive(model, test_dataloader)

        # BUG FIX: the original wrote "{...:.4f\n}" — a newline inside the
        # format spec raises ValueError at runtime. The newline belongs
        # outside the replacement field.
        term = f"[Epoch {epoch}] Train: MSE Loss = {epoch_total_loss / batches:.4f}\n"
        term += f"Test: MSE Loss = {evaluation_total_loss / len(test_dataloader):.4f}"
        term += "\n"
        print(term)

        # Saves the full module object (not just a state_dict), matching the
        # torch.load(..., weights_only=False) calls below.
        torch.save(model, f".\\{save_path}\\Epoch-{epoch}.pt")


def evaluate_contrastive(model, dataloader, dtype=torch.float):
    """Return the summed MSE value loss over `dataloader`, without gradients.

    NOTE(review): the name is a leftover from a contrastive setup — this is
    a plain MSE evaluation of the value head.
    """
    # BUG FIX: disable dropout/batchnorm updates while measuring test loss.
    model.eval()

    total_loss = 0
    criterion = MSELoss()

    with torch.no_grad():
        for bitboards, values in tqdm(dataloader):
            values = values.to("cuda", dtype)
            bitboards = bitboards.to("cuda", dtype)

            predicted_policy, predicted_values = model(bitboards)
            total_loss += criterion(predicted_values, values).item()

    return total_loss


if __name__ == "__main__":
    print("Loading...")
    # Tactic positions serve as the held-out test set; random positions train.
    test_bitboards = torch.load(".\\training\\data\\processed\\stockfish\\tactic_evals\\bitboards.pt", weights_only=False)
    test_evaluations = torch.load(".\\training\\data\\processed\\stockfish\\tactic_evals\\evaluations.pt", weights_only=False)
    test_set = StockfishDataset(test_bitboards, test_evaluations)

    train_bitboards = torch.load(".\\training\\data\\processed\\stockfish\\random_evals\\bitboards.pt", weights_only=False)
    train_evaluations = torch.load(".\\training\\data\\processed\\stockfish\\random_evals\\evaluations.pt", weights_only=False)
    train_set = StockfishDataset(train_bitboards, train_evaluations)

    train_dataloader = DataLoader(train_set, batch_size=8, shuffle=True)
    test_dataloader = DataLoader(test_set, batch_size=8, shuffle=True)

    save_path = ".\\data\\"

    model = ChessCNN()
    # model = torch.compile(model)

    print("Training...")
    train_stockfish(model, test_dataloader, train_dataloader,
                    save_path=save_path, dtype=torch.float, epochs=64)
class MultiheadAttention(nn.Module):
    """Multi-head attention whose score/softmax core runs via flash-attn.

    NOTE(review): `flash_attn_func` is not imported anywhere in the visible
    part of this module — confirm `from flash_attn import flash_attn_func`
    exists at the top of the file, and that inputs live on a CUDA device
    (flash-attn kernels are CUDA-only).
    """

    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.embed_dim = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # NOTE(review): flash_attn_func applies 1/sqrt(head_dim) scaling
        # internally by default, so this attribute is informational only.
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Return attention output of shape (batch, query_len, d_model)."""
        QB, QL, _ = query.shape
        # BUG FIX: the original reshaped k and v with the *query* length QL,
        # which breaks any cross-attention call where key/value have a
        # different sequence length. Use their own lengths.
        KB, KL, _ = key.shape
        VB, VL, _ = value.shape

        # Linear projections, split into heads: (B, L, H, D).
        q = self.q_proj(query).view(QB, QL, self.num_heads, self.head_dim)
        k = self.k_proj(key).view(KB, KL, self.num_heads, self.head_dim)
        v = self.v_proj(value).view(VB, VL, self.num_heads, self.head_dim)

        # flash-attn kernels require fp16/bf16 inputs.
        q, k, v = q.half(), k.half(), v.half()

        output = flash_attn_func(
            q, k, v,
            dropout_p=0.0,
            causal=False,
        )

        # BUG FIX: use reshape, not view — the (B, L, H, D) output is not
        # guaranteed to be contiguous in (B, L, H*D) layout.
        output = output.reshape(QB, QL, self.embed_dim)

        # NOTE(review): out_proj and dropout are constructed but never
        # applied — likely an oversight, but kept unchanged here to preserve
        # the current numerical behavior.
        return output.float()
import math
import os

import torch
from torch.utils.data import Dataset

# NOTE(review): the original square_index(rank, file) helper's body is not
# visible in this view and is no longer referenced by fen_to_bitboards.


def fen_to_bitboards(fen):
    """Convert a FEN string into piece planes plus a perspective sign.

    Returns
    -------
    planes : torch.uint8 tensor of shape (12, 8, 8). Plane order always puts
        the side to move's pieces first (P, N, B, R, Q, K order), then the
        opponent's, so the network sees a mover-relative encoding.
    multiplier : +1 if white is to move, -1 if black — used by the caller to
        flip the engine evaluation to the mover's perspective.
    """
    planes = torch.zeros((12, 8, 8), dtype=torch.uint8)

    fields = fen.split(" ")
    board = fields[0]
    to_move = fields[1]

    if to_move == "w":
        multiplier = 1
        piece_order = ["P", "N", "B", "R", "Q", "K",
                       "p", "n", "b", "r", "q", "k"]
    else:
        multiplier = -1
        piece_order = ["p", "n", "b", "r", "q", "k",
                       "P", "N", "B", "R", "Q", "K"]

    # FEN lists rank 8 first, which maps to row 0 of each plane.
    for row, rank in enumerate(board.split("/")):
        col = 0
        for ch in rank:
            if ch.isdigit():
                # A digit encodes a run of empty squares.
                col += int(ch)
            elif ch in piece_order:
                planes[piece_order.index(ch)][row][col] = 1
                col += 1

    return planes, multiplier


class StockfishDataset(Dataset):
    """Index-aligned pairs of (bitboard planes, scalar evaluation)."""

    def __init__(self, bitboards, evaluations):
        self.bitboards = bitboards
        self.evaluations = evaluations

    def __len__(self):
        return len(self.bitboards)

    def __getitem__(self, idx):
        return self.bitboards[idx], self.evaluations[idx]


if __name__ == "__main__":
    import argparse

    # Imported inside the guard so importing this module does not require tqdm.
    from tqdm import tqdm

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Path to the dataset file")
    parser.add_argument("--output_file", type=str, required=True, help="Path to the output file")
    args = parser.parse_args()

    bitboards = []
    evaluations = []

    with open(args.input_file) as f:
        for line in tqdm(f):
            # BUG FIX: the original called f.readline() inside `for line in f`,
            # which advanced the file a second time and silently dropped every
            # other row of the dataset.
            parts = line.strip("\n").split(',')
            bitboard, to_move = fen_to_bitboards(parts[0])

            value = parts[1]
            if value[0] == "#":
                # Mate score such as "#+3" / "#-5".
                # NOTE(review): this keeps only the signed mate distance, so a
                # mate-in-3 scores lower than a +100cp eval — confirm intended.
                evaluation = int(value[1:])
            else:
                evaluation = int(value)

            # Flip to the side-to-move's perspective.
            evaluation = evaluation * to_move

            # BUG FIX: the original used `2^14`, which is XOR (2^14 == 12 and
            # -2^14 == -16), clamping every evaluation into [-16, 12].
            # The intended bound is +/- 2**14 = 16384.
            evaluation = max(min(evaluation, 2 ** 14), -2 ** 14)

            bitboards.append(bitboard)
            evaluations.append(evaluation)

    os.makedirs(os.path.dirname(args.output_file), exist_ok=True)

    # output_file is treated as a directory-style prefix for both tensors.
    torch.save(bitboards, args.output_file + "bitboards.pt")
    torch.save(evaluations, args.output_file + "evaluations.pt")