From dbe28dc2d063516ef0fcf0120dd14a1500a7a054 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Fri, 17 Jan 2025 01:01:01 -0500 Subject: [PATCH 1/6] Experimental changes to get first run --- requirements.txt | 37 +++++ setup.py | 13 ++ src/transformers_learn_mdp/__init__.py | 0 .../connect4_train_mcts.py | 126 ++++++++++++++++ src/transformers_learn_mdp/data_utils.py | 33 +++++ src/transformers_learn_mdp/dataset.py | 31 ++++ src/transformers_learn_mdp/model.py | 139 ++++++++++++++++++ src/transformers_learn_mdp/trainer.py | 83 +++++++++++ 8 files changed, 462 insertions(+) create mode 100644 requirements.txt create mode 100644 setup.py create mode 100644 src/transformers_learn_mdp/__init__.py create mode 100644 src/transformers_learn_mdp/connect4_train_mcts.py create mode 100644 src/transformers_learn_mdp/data_utils.py create mode 100644 src/transformers_learn_mdp/dataset.py create mode 100644 src/transformers_learn_mdp/model.py create mode 100644 src/transformers_learn_mdp/trainer.py diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4805d71 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,37 @@ +accelerate==1.2.1 +certifi==2024.12.14 +charset-normalizer==3.4.1 +filelock==3.16.1 +fsspec==2024.12.0 +huggingface-hub==0.27.1 +idna==3.10 +Jinja2==3.1.5 +MarkupSafe==3.0.2 +mpmath==1.3.0 +networkx==3.4.2 +numpy==2.2.1 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +packaging==24.2 +pillow==11.1.0 +psutil==6.1.1 +PyYAML==6.0.2 +requests==2.32.3 +safetensors==0.5.2 +sympy==1.13.1 +torch==2.5.1 +torchvision==0.20.1 +tqdm==4.67.1 +triton==3.1.0 +typing_extensions==4.12.2 +urllib3==2.3.0 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..d8fd836 --- /dev/null +++ b/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +def read_requirements(): + with open("requirements.txt") as f: + return f.read().splitlines() + +setup( + name="transformers_learn_mdp", + version="0.1.0", + package_dir={"": "src"}, + packages=find_packages(where="src"), + install_requires=read_requirements() + ["mcts@git+https://github.com/metric-space/mcts.git"] +) diff --git a/src/transformers_learn_mdp/__init__.py b/src/transformers_learn_mdp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/transformers_learn_mdp/connect4_train_mcts.py b/src/transformers_learn_mdp/connect4_train_mcts.py new file mode 100644 index 0000000..53037b9 --- /dev/null +++ b/src/transformers_learn_mdp/connect4_train_mcts.py @@ -0,0 +1,126 @@ +import os +import sys +import pickle +import shutil +import torch +import argparse +from tqdm import tqdm + +#sys.path.append('../') + +from accelerate import Accelerator +from .dataset import EpisodeDataset, collate_fn +from .model import Config, GPTModel +from .trainer import train_model, validate_model +from torch.utils.data import DataLoader + +from .data_utils import information_parser + + +""" +Training pipeline for transformer on Connect-4 data generated through MCTS. +""" + +def train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, mode, seed, save_directory = None, epochs = 15): + + accelerator = Accelerator() + + train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn) + valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn) + + config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_layers, n_embd=embed_size) + model = GPTModel(config) + + optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs) + + train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare(train_loader, valid_loader, model, scheduler, optimizer) + + epoch = 0 + + model_path = None + min_loss = 1e10 + + train_losses = [] + valid_losses = [] + + for epoch in tqdm(range(epochs)): + accelerator.print(f'Epoch {epoch}') + + train_loss = train_model(model, train_loader, optimizer, accelerator) + valid_loss = validate_model(model, valid_loader, accelerator) + train_losses.append(train_loss) + valid_losses.append(valid_loss) + scheduler.step() + + if accelerator.is_main_process: + print(f'Validation Loss: {valid_loss:.8f}') + + model_save_path = f"model_{epoch+1}_mode_{mode}_seed_{seed}.pth" + accelerator.save(accelerator.unwrap_model(model).state_dict(), model_save_path) + + if valid_loss < min_loss: + min_loss = valid_loss + model_path = model_save_path + + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + shutil.copy(model_path, save_directory) + + with open(f'train_losses_mode_{mode}_seed_{seed}.pkl', 'wb') as f: + pickle.dump(train_losses, f) + with open(f'valid_losses_mode_{mode}_seed_{seed}.pkl', 'wb') as f: + pickle.dump(valid_losses, f) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-m', type=int, default=1, choices=[0, 1, 2], help='Data Mode (state, action, state-action)') + parser.add_argument('-s', type=int, default=23456, choices=[0, 1, 2], help='Seed') + parser.add_argument('-i', type=str, help='Input Path') + args = parser.parse_args() + if args.m == 0: + token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} + vocab_size = 43 + elif args.m == 1: + token_to_idx = {i: i + 1 for i in range(7)} + vocab_size = 8 + elif args.m == 2: + token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} | {i: i + 44 for i in range(7)} + vocab_size = 51 + token_to_idx[''] = 0 # Padding token + block_size = 42 + embed_size = 512 + num_layers = 8 + + path = '' + + with open(args.i, 'r') as f: + data = f.readlines() + data = information_parser(data) + agent1 = [[action for (_,action) in x] for x in data] + agent1 = [(actions[:-1], actions[1:]) for actions in agent1] + + + train_ratio = 0.8 + valid_ratio = 0.1 + + d1 = len(agent1) + + train = agent1[:int(train_ratio * d1)] + valid = agent1[int(train_ratio * d1):int((train_ratio + valid_ratio) * d1) ] + test = agent1[int((train_ratio + valid_ratio) * d1): ] + + print(len(train)) + print(len(valid)) + print(len(test)) + + train_dataset = EpisodeDataset(train, token_to_idx) + valid_dataset = EpisodeDataset(valid, token_to_idx) + + #train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, args.m, args.s, "best_model") + +if __name__ == "__main__": + main() + diff --git a/src/transformers_learn_mdp/data_utils.py b/src/transformers_learn_mdp/data_utils.py new file mode 100644 index 0000000..62b36e8 --- /dev/null +++ b/src/transformers_learn_mdp/data_utils.py @@ -0,0 +1,33 @@ +from typing import List + + +def information_parser(info: List[str]): + """ + + + """ + # + parsed_info = [] + + for line in info: + temp = [] + raw = line.split(",") + counter = 0 + while counter < len(raw): + + leap_steps = int(raw[counter]) * 2 + counter += 1 + + q_values = {} + fragment = raw[counter:counter + leap_steps ] + zip_object = zip(fragment[::2], fragment[1::2]) + for key, value in zip_object: + q_values[int(key)] = float(value) + counter += leap_steps + + temp.append((q_values, int(raw[counter]))) + counter += 1 + + parsed_info.append(temp) + + return parsed_info \ No newline at end of file diff --git a/src/transformers_learn_mdp/dataset.py b/src/transformers_learn_mdp/dataset.py new file mode 100644 index 0000000..ff5ff4a --- /dev/null +++ b/src/transformers_learn_mdp/dataset.py @@ -0,0 +1,31 @@ +import torch + +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence + +class EpisodeDataset(Dataset): + + def __init__(self, data, token_to_idx): + self.data = data + self.token_to_idx = token_to_idx + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + X_sequence, Y_sequence = self.data[idx] + + X_indices = [self.token_to_idx[token] for token in X_sequence] + Y_indices = [self.token_to_idx[token] for token in Y_sequence] + + return torch.tensor(X_indices, dtype=torch.long), torch.tensor(Y_indices, dtype=torch.long) + +def collate_fn(batch): + + Xs, Ys = zip(*batch) + + Xs_padded = pad_sequence(Xs, batch_first=True, padding_value=0) + Ys_padded = pad_sequence(Ys, batch_first=True, padding_value=0) + + return Xs_padded, Ys_padded \ No newline at end of file diff --git a/src/transformers_learn_mdp/model.py b/src/transformers_learn_mdp/model.py new file mode 100644 index 0000000..33ab6d0 --- /dev/null +++ b/src/transformers_learn_mdp/model.py @@ -0,0 +1,139 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Config: + """ + Configuration for the GPT model including size parameters and dropout rates. + """ + def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer): + self.vocab_size = vocab_size # Vocabulary size + self.block_size = block_size # Input sequence length + self.n_embd = n_embd # Embedding dimension + self.n_head = n_head # Number of attention heads + self.n_layer = n_layer # Number of transformer layers + # Dropout rates + self.embd_pdrop = 0.1 + self.resid_pdrop = 0.1 + self.attn_pdrop = 0.1 + +class CausalSelfAttention(nn.Module): + """ + Causal self-attention module implementing scaled dot-product attention. + """ + def __init__(self, config): + super().__init__() + self.config = config + assert self.config.n_embd % self.config.n_head == 0, "embedding dimension must be divisible by the number of heads." + + # Key, query, and value linear transformations + self.key = nn.Linear(self.config.n_embd, self.config.n_embd) + self.query = nn.Linear(self.config.n_embd, self.config.n_embd) + self.value = nn.Linear(self.config.n_embd, self.config.n_embd) + + # Dropout layers + self.attn_drop = nn.Dropout(self.config.attn_pdrop) + self.resid_drop = nn.Dropout(self.config.resid_pdrop) + + # Output projection + self.proj = nn.Linear(self.config.n_embd, self.config.n_embd) + + # Causal mask to prevent attention to future tokens + self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)).unsqueeze(0).unsqueeze(1)) + + def forward(self, x): + B, T, C = x.size() # Batch size, sequence length, embedding dimension + + # Calculate projections and reshape for multi-headed attention + k = self.key(x).view(B, T, self.config.n_head, C // self.config.n_head).transpose(1, 2) + q = self.query(x).view(B, T, self.config.n_head, C // self.config.n_head).transpose(1, 2) + v = self.value(x).view(B, T, self.config.n_head, C // self.config.n_head).transpose(1, 2) + + # Compute the attention scores + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) @ v + + # Re-assemble all head outputs side by side + y = att.transpose(1, 2).contiguous().view(B, T, C) + + # Output projection + y = self.resid_drop(self.proj(y)) + return y + +class TransformerBlock(nn.Module): + """ + A single transformer block containing a causal self-attention layer and a feed-forward network. + """ + def __init__(self, config): + super().__init__() + self.attn = CausalSelfAttention(config) + self.ln1 = nn.LayerNorm(config.n_embd) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + nn.GELU(), + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.resid_pdrop), + ) + self.ln2 = nn.LayerNorm(config.n_embd) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class GPTModel(nn.Module): + """ + The GPT model comprising an embedding layer, multiple transformer blocks, and a final output layer. + """ + def __init__(self, config): + super().__init__() + self.config = config + + # Embedding layers + self.tok_emb = nn.Embedding(self.config.vocab_size, self.config.n_embd) + self.pos_emb = nn.Parameter(torch.zeros(1, self.config.block_size, self.config.n_embd)) + self.drop = nn.Dropout(self.config.embd_pdrop) + + # Transformer blocks + self.blocks = nn.Sequential(*[TransformerBlock(config) for _ in range(self.config.n_layer)]) + + # Final layer normalization and linear output layer + self.ln_f = nn.LayerNorm(self.config.n_embd) + self.head = nn.Linear(self.config.n_embd, self.config.vocab_size, bias=False) + + self.apply(self._init_weights) + + def _init_weights(self, module): + # Initialize weights and biases + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, idx, num_probe=None): + # Processing input + _, T = idx.size() + assert T <= self.config.block_size, "Input sequence too long." + + token_embeddings = self.tok_emb(idx) + position_embeddings = self.pos_emb[:, :T, :] + x = self.drop(token_embeddings + position_embeddings) + + # Retrieve Embedding if num_probe is specified + if num_probe is not None: + for block in self.blocks[:num_probe]: + x = block(x) + return x + + # Process through all blocks, then project to vocabulary size + x = self.blocks(x) + x = self.ln_f(x) + logits = self.head(x) + + return logits diff --git a/src/transformers_learn_mdp/trainer.py b/src/transformers_learn_mdp/trainer.py new file mode 100644 index 0000000..28144f5 --- /dev/null +++ b/src/transformers_learn_mdp/trainer.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn + +from tqdm import tqdm + +def validate_model(model, valid_loader, accelerator): + + model.eval() + criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none') + + valid_loss = torch.tensor(0.0).to(accelerator.device) + valid_data = torch.tensor(0.0).to(accelerator.device) + + with torch.no_grad(): + + for X_batch, Y_batch in valid_loader: + + logits = model(X_batch) + logits = logits.view(-1, logits.size(-1)) # Shape: [batch_size * seq_length, vocab_size] + Y_batch = Y_batch.view(-1) # Shape: [batch_size * seq_length] + + # Assuming the padding token index is 0 + padding_token_index = 0 + mask = (Y_batch != padding_token_index).float() # Create a mask for valid positions + + loss = criterion(logits, Y_batch) # Calculate loss without reduction + masked_loss = loss * mask # Apply mask + + valid_loss += masked_loss.sum().item() # Sum the losses at valid positions + valid_data += mask.sum().item() # Count valid positions + + accelerator.wait_for_everyone() + + valid_loss = accelerator.gather(valid_loss).sum() + valid_data = accelerator.gather(valid_data).sum() + + if accelerator.is_main_process: + return (valid_loss / valid_data).item() + else: + return None + +def train_model(model, train_loader, optimizer, accelerator): + + model.train() + criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none') + + train_loss = torch.tensor(0.0).to(accelerator.device) + train_data = torch.tensor(0.0).to(accelerator.device) + + for X_batch, Y_batch in tqdm(train_loader, desc="Training"): + + optimizer.zero_grad() + logits = model(X_batch) + + logits = logits.view(-1, logits.size(-1)) # Shape: [batch_size * seq_length, vocab_size] + Y_batch = Y_batch.view(-1) # Shape: [batch_size * seq_length] + + padding_token_index = 0 # Assuming the padding token index is 0 + mask = (Y_batch != padding_token_index).float() + + loss = criterion(logits, Y_batch) + masked_loss = loss * mask + + loss_sum = masked_loss.sum() + data_sum = mask.sum() + + loss = loss_sum / data_sum + accelerator.backward(loss) + optimizer.step() + + train_loss += loss_sum.item() + train_data += mask.sum().item() + + accelerator.wait_for_everyone() + + train_loss = accelerator.gather(train_loss).sum() + train_data = accelerator.gather(train_data).sum() + + accelerator.print('Training Loss:', (train_loss / train_data).item()) + + return train_loss + + From cafe152c60c4ddef960c1f5a066f235071e24fcd Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Tue, 21 Jan 2025 00:45:25 -0500 Subject: [PATCH 2/6] Better lr scheduler, actions -> proper board coordinates, wandb addition --- requirements.txt | 21 ++++++++ .../connect4_train_mcts.py | 38 ++++++++------ src/transformers_learn_mdp/dataset.py | 49 ++++++++++++++++--- src/transformers_learn_mdp/trainer.py | 32 ++++++++++-- 4 files changed, 116 insertions(+), 24 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4805d71..4d1df9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,23 @@ accelerate==1.2.1 +annotated-types==0.7.0 certifi==2024.12.14 charset-normalizer==3.4.1 +click==8.1.8 +cloudpickle==3.1.1 +colorama==0.4.6 +docker-pycreds==0.4.0 filelock==3.16.1 fsspec==2024.12.0 +gitdb==4.0.12 +GitPython==3.1.44 +gym==0.26.2 +gym-notices==0.0.8 +gym_connect4 @ git+https://github.com/Danielhp95/gym-connect4.git@bfc12d659308dfcf1132a31aee9b52eceb8901b5 huggingface-hub==0.27.1 idna==3.10 Jinja2==3.1.5 MarkupSafe==3.0.2 +mcts @ git+https://github.com/metric-space/mcts.git@6028ada55d9690238c2db14d423c34d98698999a mpmath==1.3.0 networkx==3.4.2 numpy==2.2.1 @@ -24,14 +35,24 @@ nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127 packaging==24.2 pillow==11.1.0 +platformdirs==4.3.6 +protobuf==5.29.3 psutil==6.1.1 +pydantic==2.10.5 +pydantic_core==2.27.2 PyYAML==6.0.2 requests==2.32.3 safetensors==0.5.2 +sentry-sdk==2.20.0 +setproctitle==1.3.4 +six==1.17.0 +smmap==5.0.2 sympy==1.13.1 torch==2.5.1 torchvision==0.20.1 tqdm==4.67.1 +-e git+ssh://git@github.com/llm-engineering/transformers-learn-MDP.git@dbe28dc2d063516ef0fcf0120dd14a1500a7a054#egg=transformers_learn_mdp triton==3.1.0 typing_extensions==4.12.2 urllib3==2.3.0 +wandb==0.19.4 diff --git a/src/transformers_learn_mdp/connect4_train_mcts.py b/src/transformers_learn_mdp/connect4_train_mcts.py index 53037b9..3dff8c2 100644 --- a/src/transformers_learn_mdp/connect4_train_mcts.py +++ b/src/transformers_learn_mdp/connect4_train_mcts.py @@ -4,10 +4,9 @@ import shutil import torch import argparse +import wandb from tqdm import tqdm -#sys.path.append('../') - from accelerate import Accelerator from .dataset import EpisodeDataset, collate_fn from .model import Config, GPTModel @@ -17,22 +16,25 @@ from .data_utils import information_parser +wandb.init(project="mdp-learning") + + """ Training pipeline for transformer on Connect-4 data generated through MCTS. """ - -def train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, mode, seed, save_directory = None, epochs = 15): +def train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, mode, seed, save_directory = None, epochs = 50): accelerator = Accelerator() - train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn) - valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn) + train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn) + valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn) - config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_layers, n_embd=embed_size) + config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_layers // 2, n_embd=embed_size) model = GPTModel(config) optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs) + scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.005, steps_per_epoch=len(train_loader), epochs=epochs) + #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs) train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare(train_loader, valid_loader, model, scheduler, optimizer) @@ -46,15 +48,20 @@ def train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, for epoch in tqdm(range(epochs)): accelerator.print(f'Epoch {epoch}') + wandb.log({"Epoch": epoch}) - train_loss = train_model(model, train_loader, optimizer, accelerator) + train_loss = train_model(model, train_loader, optimizer, accelerator, scheduler) valid_loss = validate_model(model, valid_loader, accelerator) train_losses.append(train_loss) valid_losses.append(valid_loss) scheduler.step() + #print("Learning Rate: ", scheduler.get_last_lr()) + if accelerator.is_main_process: - print(f'Validation Loss: {valid_loss:.8f}') + val_loss_str = f'Validation loss {valid_loss:.8f}' + wandb.log({"Validation Loss": valid_loss, "Training Loss": train_loss}) + accelerator.print(val_loss_str) model_save_path = f"model_{epoch+1}_mode_{mode}_seed_{seed}.pth" accelerator.save(accelerator.unwrap_model(model).state_dict(), model_save_path) @@ -72,11 +79,13 @@ def train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, pickle.dump(train_losses, f) with open(f'valid_losses_mode_{mode}_seed_{seed}.pkl', 'wb') as f: pickle.dump(valid_losses, f) + + wandb.finish() def main(): parser = argparse.ArgumentParser() - parser.add_argument('-m', type=int, default=1, choices=[0, 1, 2], help='Data Mode (state, action, state-action)') + parser.add_argument('-m', type=int, default=0, choices=[0, 1, 2], help='Data Mode (state, action, state-action)') parser.add_argument('-s', type=int, default=23456, choices=[0, 1, 2], help='Seed') parser.add_argument('-i', type=str, help='Input Path') args = parser.parse_args() @@ -90,8 +99,8 @@ def main(): token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} | {i: i + 44 for i in range(7)} vocab_size = 51 token_to_idx[''] = 0 # Padding token - block_size = 42 - embed_size = 512 + block_size = 52 + embed_size = 64 num_layers = 8 path = '' @@ -100,7 +109,6 @@ def main(): data = f.readlines() data = information_parser(data) agent1 = [[action for (_,action) in x] for x in data] - agent1 = [(actions[:-1], actions[1:]) for actions in agent1] train_ratio = 0.8 @@ -119,7 +127,7 @@ def main(): train_dataset = EpisodeDataset(train, token_to_idx) valid_dataset = EpisodeDataset(valid, token_to_idx) - #train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, args.m, args.s, "best_model") + train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, args.m, args.s, "best_model") if __name__ == "__main__": main() diff --git a/src/transformers_learn_mdp/dataset.py b/src/transformers_learn_mdp/dataset.py index ff5ff4a..e30a3b8 100644 --- a/src/transformers_learn_mdp/dataset.py +++ b/src/transformers_learn_mdp/dataset.py @@ -2,22 +2,59 @@ from torch.utils.data import Dataset from torch.nn.utils.rnn import pad_sequence +import tqdm + +def actions_to_col_row(actions, board_height=6): + """ + Converts a sequence of Connect4 column moves into (column, row) pairs. + + Args: + actions (list): List of column indices (0-6) representing moves. + board_height (int): Number of rows in Connect4 (default: 6). + + Returns: + list of tuples: [(col, row), ...] where row is where the piece lands. + """ + heights = [0] * 7 # Track how filled each column is + col_row_sequence = [] + + for col in actions: + row = board_height - 1 - heights[col] # Compute the landing row + if row < 0: + raise ValueError(f"Invalid move: Column {col} is full!") + + col_row_sequence.append((row, col)) + heights[col] += 1 # Update column height + + return col_row_sequence + class EpisodeDataset(Dataset): - def __init__(self, data, token_to_idx): - self.data = data + def __init__(self, data, token_to_idx, packing_length=30,padding_value=0): self.token_to_idx = token_to_idx - + print("Tokenizing and packing the dataset") + self.packed_data = [] + + self.tokenized_data = [[self.token_to_idx[token] for token in actions_to_col_row(sequence)] for sequence in data] + # flatten the list and insert padding value at the end of each sequence + #self.data = [] + #for sequence in self.tokenized_data: + # self.data.extend(sequence) + # self.data.append(padding_value) + #del self.tokenized_data + #self.data = [self.data[i:i+packing_length] for i in range(0, len(self.data), packing_length)] + self.data = self.tokenized_data + def __len__(self): return len(self.data) def __getitem__(self, idx): - X_sequence, Y_sequence = self.data[idx] + sequence = self.data[idx] - X_indices = [self.token_to_idx[token] for token in X_sequence] - Y_indices = [self.token_to_idx[token] for token in Y_sequence] + X_indices = sequence[:-1] + Y_indices = sequence[1:] return torch.tensor(X_indices, dtype=torch.long), torch.tensor(Y_indices, dtype=torch.long) diff --git a/src/transformers_learn_mdp/trainer.py b/src/transformers_learn_mdp/trainer.py index 28144f5..d4ae39e 100644 --- a/src/transformers_learn_mdp/trainer.py +++ b/src/transformers_learn_mdp/trainer.py @@ -39,10 +39,10 @@ def validate_model(model, valid_loader, accelerator): else: return None -def train_model(model, train_loader, optimizer, accelerator): +def train_model(model, train_loader, optimizer, accelerator, scheduler=None): model.train() - criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none') + criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none', label_smoothing=0.0) train_loss = torch.tensor(0.0).to(accelerator.device) train_data = torch.tensor(0.0).to(accelerator.device) @@ -67,9 +67,32 @@ def train_model(model, train_loader, optimizer, accelerator): loss = loss_sum / data_sum accelerator.backward(loss) optimizer.step() + if scheduler is not None: + scheduler.step() train_loss += loss_sum.item() train_data += mask.sum().item() + + + grad_norms = [] + + for param in model.parameters(): + if param.grad is not None: + grad_norms.append(param.grad.norm().item()) + + if len(grad_norms) == 0: + return None # No gradients yet (e.g., before first backward pass) + + grad_tensor = torch.tensor(grad_norms) + + stats = { + "mean": grad_tensor.mean().item(), + "max": grad_tensor.max().item(), + "std": grad_tensor.std().item(), + "p95": grad_tensor.quantile(0.95).item() + } + + #accelerator.print('Gradient norm:', stats) accelerator.wait_for_everyone() @@ -78,6 +101,9 @@ def train_model(model, train_loader, optimizer, accelerator): accelerator.print('Training Loss:', (train_loss / train_data).item()) - return train_loss + if accelerator.is_main_process: + return (train_loss / train_data).item() + else: + return None From a03aae763a28384eddd7457f71f4d2770e4afb3e Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 22 Jan 2025 04:26:59 -0500 Subject: [PATCH 3/6] Add hydra to the show --- requirements.txt | 5 +- .../connect4_train_mcts.py | 165 +++++++++++------- src/transformers_learn_mdp/dataset.py | 2 +- src/transformers_learn_mdp/trainer.py | 20 ++- train.sh | 1 + 5 files changed, 122 insertions(+), 71 deletions(-) create mode 100644 train.sh diff --git a/requirements.txt b/requirements.txt index 4d1df9a..9d36c8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ accelerate==1.2.1 annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 certifi==2024.12.14 charset-normalizer==3.4.1 click==8.1.8 @@ -14,6 +15,7 @@ gym==0.26.2 gym-notices==0.0.8 gym_connect4 @ git+https://github.com/Danielhp95/gym-connect4.git@bfc12d659308dfcf1132a31aee9b52eceb8901b5 huggingface-hub==0.27.1 +hydra-core==1.3.2 idna==3.10 Jinja2==3.1.5 MarkupSafe==3.0.2 @@ -33,6 +35,7 @@ nvidia-cusparse-cu12==12.3.1.170 nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127 +omegaconf==2.3.0 packaging==24.2 pillow==11.1.0 platformdirs==4.3.6 @@ -51,7 +54,7 @@ sympy==1.13.1 torch==2.5.1 torchvision==0.20.1 tqdm==4.67.1 --e git+ssh://git@github.com/llm-engineering/transformers-learn-MDP.git@dbe28dc2d063516ef0fcf0120dd14a1500a7a054#egg=transformers_learn_mdp +-e git+ssh://git@github.com/llm-engineering/transformers-learn-MDP.git@cafe152c60c4ddef960c1f5a066f235071e24fcd#egg=transformers_learn_mdp triton==3.1.0 typing_extensions==4.12.2 urllib3==2.3.0 diff --git a/src/transformers_learn_mdp/connect4_train_mcts.py b/src/transformers_learn_mdp/connect4_train_mcts.py index 3dff8c2..14bd55e 100644 --- a/src/transformers_learn_mdp/connect4_train_mcts.py +++ b/src/transformers_learn_mdp/connect4_train_mcts.py @@ -3,9 +3,10 @@ import pickle import shutil import torch -import argparse import wandb from tqdm import tqdm +import hydra +from omegaconf import DictConfig, OmegaConf, open_dict from accelerate import Accelerator from .dataset import EpisodeDataset, collate_fn @@ -14,57 +15,95 @@ from torch.utils.data import DataLoader from .data_utils import information_parser +from enum import Enum -wandb.init(project="mdp-learning") +class Mode(Enum): + STATE = 0 + ACTION = 1 + STATE_ACTION = 2 -""" -Training pipeline for transformer on Connect-4 data generated through MCTS. -""" -def train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, mode, seed, save_directory = None, epochs = 50): - - accelerator = Accelerator() +def train(training_config, training_dataset, validation_dataset, token_to_idx, wandb): + + train_dataset = EpisodeDataset(training_dataset, token_to_idx) + valid_dataset = EpisodeDataset(validation_dataset, token_to_idx) - train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn) - valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn) + accelerator = Accelerator() - config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_layers // 2, n_embd=embed_size) + train_loader = DataLoader( + train_dataset, + batch_size=training_config.batch_size, + shuffle=True, + collate_fn=collate_fn, + ) + valid_loader = DataLoader( + valid_dataset, + batch_size=training_config.batch_size, + shuffle=True, + collate_fn=collate_fn, + ) + + config = Config( + training_config.vocab_size, + training_config.seq_len, + n_layer=training_config.num_layers, + n_head=training_config.num_heads, + n_embd=training_config.embedding_size, + ) model = GPTModel(config) - optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01) - scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.005, steps_per_epoch=len(train_loader), epochs=epochs) - #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs) - - train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare(train_loader, valid_loader, model, scheduler, optimizer) + optimizer = torch.optim.AdamW( + model.parameters(), + lr=training_config.lr, + weight_decay=training_config.weight_decay, + ) + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=0.0006, + steps_per_epoch=len(train_loader), + epochs=training_config.epochs, + ) + # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs) + + train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare( + train_loader, valid_loader, model, scheduler, optimizer + ) epoch = 0 model_path = None min_loss = 1e10 - + train_losses = [] valid_losses = [] - for epoch in tqdm(range(epochs)): - accelerator.print(f'Epoch {epoch}') + for epoch in tqdm(range(training_config.epochs), desc="Epoch"): + accelerator.print(f"Epoch {epoch}") wandb.log({"Epoch": epoch}) - train_loss = train_model(model, train_loader, optimizer, accelerator, scheduler) + train_loss = train_model( + model, train_loader, optimizer, accelerator, scheduler, wandb + ) valid_loss = validate_model(model, valid_loader, accelerator) train_losses.append(train_loss) valid_losses.append(valid_loss) scheduler.step() - #print("Learning Rate: ", scheduler.get_last_lr()) + # print("Learning Rate: ", scheduler.get_last_lr()) + + mode = training_config.mode + seed = training_config.seed if accelerator.is_main_process: - val_loss_str = f'Validation loss {valid_loss:.8f}' + val_loss_str = f"Validation loss {valid_loss:.8f}" wandb.log({"Validation Loss": valid_loss, "Training Loss": train_loss}) accelerator.print(val_loss_str) model_save_path = f"model_{epoch+1}_mode_{mode}_seed_{seed}.pth" - accelerator.save(accelerator.unwrap_model(model).state_dict(), model_save_path) + accelerator.save( + accelerator.unwrap_model(model).state_dict(), model_save_path + ) if valid_loss < min_loss: min_loss = valid_loss @@ -73,62 +112,68 @@ def train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, accelerator.wait_for_everyone() if accelerator.is_main_process: - shutil.copy(model_path, save_directory) + shutil.copy(model_path, training_config.save_directory) - with open(f'train_losses_mode_{mode}_seed_{seed}.pkl', 'wb') as f: + with open(f"train_losses_mode_{mode}_seed_{seed}.pkl", "wb") as f: pickle.dump(train_losses, f) - with open(f'valid_losses_mode_{mode}_seed_{seed}.pkl', 'wb') as f: + with open(f"valid_losses_mode_{mode}_seed_{seed}.pkl", "wb") as f: pickle.dump(valid_losses, f) - + wandb.finish() -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('-m', type=int, default=0, choices=[0, 1, 2], help='Data Mode (state, action, state-action)') - parser.add_argument('-s', type=int, default=23456, choices=[0, 1, 2], help='Seed') - parser.add_argument('-i', type=str, help='Input Path') - args = parser.parse_args() - if args.m == 0: +def split_dataset(data, train_ratio, valid_ratio): + train = data[: int(train_ratio * len(data))] + valid = data[ + int(train_ratio * len(data)) : int((train_ratio + valid_ratio) * len(data)) + ] + test = data[int((train_ratio + valid_ratio) * len(data)) :] + return train, valid, test + + +def mode_to_token_to_idx(mode): + if mode == 0: token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} vocab_size = 43 - elif args.m == 1: + elif mode == 1: token_to_idx = {i: i + 1 for i in range(7)} vocab_size = 8 - elif args.m == 2: - token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} | {i: i + 44 for i in range(7)} + elif mode == 2: + token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} | { + i: i + 44 for i in range(7) + } vocab_size = 51 - token_to_idx[''] = 0 # Padding token - block_size = 52 - embed_size = 64 - num_layers = 8 - - path = '' - - with open(args.i, 'r') as f: + token_to_idx[""] = 0 # Padding token + return token_to_idx, vocab_size + + +@hydra.main(version_base=None, config_path="../../conf", config_name="config") +def main(cfg: DictConfig) -> None: + + training_config = cfg.training + + mode = training_config.mode + token_to_idx, vocab_size = mode_to_token_to_idx(mode) + + # Make this a function + with open(training_config.data_path, "r") as f: data = f.readlines() data = information_parser(data) - agent1 = [[action for (_,action) in x] for x in data] - + raw_dataset = [[action for (_, action) in x] for x in data] - train_ratio = 0.8 - valid_ratio = 0.1 - d1 = len(agent1) + training_dataset, validation_dataset, test_dataset = split_dataset( + raw_dataset, training_config.train_ratio, training_config.val_ratio + ) - train = agent1[:int(train_ratio * d1)] - valid = agent1[int(train_ratio * d1):int((train_ratio + valid_ratio) * d1) ] - test = agent1[int((train_ratio + valid_ratio) * d1): ] + with open_dict(training_config): + training_config["vocab_size"] = vocab_size + training_config["dataset_length"] = len(raw_dataset) - print(len(train)) - print(len(valid)) - print(len(test)) + wandb.init(project=cfg.wandb.project_name, config=dict(training_config)) - train_dataset = EpisodeDataset(train, token_to_idx) - valid_dataset = EpisodeDataset(valid, token_to_idx) + train(training_config, training_dataset, validation_dataset, token_to_idx, wandb) - train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, args.m, args.s, "best_model") if __name__ == "__main__": main() - diff --git a/src/transformers_learn_mdp/dataset.py b/src/transformers_learn_mdp/dataset.py index e30a3b8..292ecef 100644 --- a/src/transformers_learn_mdp/dataset.py +++ b/src/transformers_learn_mdp/dataset.py @@ -36,7 +36,7 @@ def __init__(self, data, token_to_idx, packing_length=30,padding_value=0): print("Tokenizing and packing the dataset") self.packed_data = [] - self.tokenized_data = [[self.token_to_idx[token] for token in actions_to_col_row(sequence)] for sequence in data] + self.tokenized_data = [[self.token_to_idx[token] for token in sequence] for sequence in data] # flatten the list and insert padding value at the end of each sequence #self.data = [] #for sequence in self.tokenized_data: diff --git a/src/transformers_learn_mdp/trainer.py b/src/transformers_learn_mdp/trainer.py index d4ae39e..6d1add1 100644 --- a/src/transformers_learn_mdp/trainer.py +++ b/src/transformers_learn_mdp/trainer.py @@ -15,9 +15,9 @@ def validate_model(model, valid_loader, accelerator): for X_batch, Y_batch in valid_loader: - logits = model(X_batch) + logits = model(X_batch)[:,1::2,:].contiguous() # Shape: [batch_size, seq_length, vocab_size] logits = logits.view(-1, logits.size(-1)) # Shape: [batch_size * seq_length, vocab_size] - Y_batch = Y_batch.view(-1) # Shape: [batch_size * seq_length] + Y_batch = Y_batch[:,1::2].contiguous().view(-1) # Shape: [batch_size * seq_length] # Assuming the padding token index is 0 padding_token_index = 0 @@ -39,7 +39,7 @@ def validate_model(model, valid_loader, accelerator): else: return None -def train_model(model, train_loader, optimizer, accelerator, scheduler=None): +def train_model(model, train_loader, optimizer, accelerator, scheduler, wandb): model.train() criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none', label_smoothing=0.0) @@ -50,10 +50,10 @@ def train_model(model, train_loader, optimizer, accelerator, scheduler=None): for X_batch, Y_batch in tqdm(train_loader, desc="Training"): optimizer.zero_grad() - logits = model(X_batch) + logits = model(X_batch)[:,1::2,:].contiguous() logits = logits.view(-1, logits.size(-1)) # Shape: [batch_size * seq_length, vocab_size] - Y_batch = Y_batch.view(-1) # Shape: [batch_size * seq_length] + Y_batch = Y_batch[:,1::2].contiguous().view(-1) # Shape: [batch_size * seq_length] padding_token_index = 0 # Assuming the padding token index is 0 mask = (Y_batch != padding_token_index).float() @@ -69,6 +69,7 @@ def train_model(model, train_loader, optimizer, accelerator, scheduler=None): optimizer.step() if scheduler is not None: scheduler.step() + wandb.log({"Learning Rate": scheduler.get_last_lr()[0]}) train_loss += loss_sum.item() train_data += mask.sum().item() @@ -86,13 +87,14 @@ def train_model(model, train_loader, optimizer, accelerator, scheduler=None): grad_tensor = torch.tensor(grad_norms) stats = { - "mean": grad_tensor.mean().item(), - "max": grad_tensor.max().item(), - "std": grad_tensor.std().item(), - "p95": grad_tensor.quantile(0.95).item() + "grad_mean": grad_tensor.mean().item(), + "grad_max": grad_tensor.max().item(), + "grad_std": grad_tensor.std().item(), + "grad_p95": grad_tensor.quantile(0.95).item() } #accelerator.print('Gradient norm:', stats) + wandb.log(stats) accelerator.wait_for_everyone() diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..82eb16c --- /dev/null +++ b/train.sh @@ -0,0 +1 @@ +python -m transformers_learn_mdp.connect4_train_mcts From 3766897ca8b51bf163da6766d469ef9de3707eb2 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 22 Jan 2025 17:58:17 -0500 Subject: [PATCH 4/6] Hydra conf --- conf/config.yaml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 conf/config.yaml diff --git a/conf/config.yaml b/conf/config.yaml new file mode 100644 index 0000000..656a8ac --- /dev/null +++ b/conf/config.yaml @@ -0,0 +1,18 @@ +wandb: + project_name: "mdp-learning" + +training: + mode: 1 + num_layers: 8 + num_heads: 4 + batch_size: 128 + seq_len: 52 + train_ratio: 0.8 + val_ratio: 0.1 + data_path: "info.txt" + epochs: 50 + lr: 0.0001 + weight_decay: 0.01 + save_directory: "./" + embedding_size: 128 + seed: 1243 From c850e6ca44da57cd7427890bf7356c27bbd8adf0 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Sat, 1 Feb 2025 17:11:12 -0500 Subject: [PATCH 5/6] Refactor code + make it more configurable --- conf/config.yaml | 15 +- .../connect4_train_mcts.py | 50 ++-- src/transformers_learn_mdp/data_utils.py | 25 ++ src/transformers_learn_mdp/dataset.py | 28 +-- src/transformers_learn_mdp/trainer.py | 224 +++++++++++++++--- .../connect4_train_mcts.py | 118 --------- transformer_training_mcts/dataset.py | 31 --- transformer_training_mcts/model.py | 139 ----------- transformer_training_mcts/trainer.py | 83 ------- 9 files changed, 253 insertions(+), 460 deletions(-) delete mode 100644 transformer_training_mcts/connect4_train_mcts.py delete mode 100644 transformer_training_mcts/dataset.py delete mode 100644 transformer_training_mcts/model.py delete mode 100644 transformer_training_mcts/trainer.py diff --git a/conf/config.yaml b/conf/config.yaml index 656a8ac..84d8896 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -1,18 +1,21 @@ wandb: - project_name: "mdp-learning" + project_name: "board-representation-experiments" + id: "test-run-delete-soon-1" training: - mode: 1 + mode: 2 num_layers: 8 num_heads: 4 - batch_size: 128 - seq_len: 52 + batch_size: 64 + seq_len: 100 train_ratio: 0.8 val_ratio: 0.1 data_path: "info.txt" - epochs: 50 + epochs: 5 lr: 0.0001 weight_decay: 0.01 - save_directory: "./" + save_directory: "./save" embedding_size: 128 seed: 1243 + loss_type: 0 + seq_type: 1 diff --git a/src/transformers_learn_mdp/connect4_train_mcts.py b/src/transformers_learn_mdp/connect4_train_mcts.py index 14bd55e..7ae53d8 100644 --- a/src/transformers_learn_mdp/connect4_train_mcts.py +++ b/src/transformers_learn_mdp/connect4_train_mcts.py @@ -6,24 +6,19 @@ import wandb from tqdm import tqdm import hydra +import itertools from omegaconf import DictConfig, OmegaConf, open_dict from accelerate import Accelerator from .dataset import EpisodeDataset, collate_fn from .model import Config, GPTModel -from .trainer import train_model, validate_model +from .trainer import train_model, validate_model, Loss, Mode, SeqSubSet from torch.utils.data import DataLoader -from .data_utils import information_parser +from .data_utils import information_parser, actions_to_col_row from enum import Enum -class Mode(Enum): - STATE = 0 - ACTION = 1 - STATE_ACTION = 2 - - def train(training_config, training_dataset, validation_dataset, token_to_idx, wandb): train_dataset = EpisodeDataset(training_dataset, token_to_idx) @@ -53,14 +48,19 @@ def train(training_config, training_dataset, validation_dataset, token_to_idx, w ) model = GPTModel(config) - optimizer = torch.optim.AdamW( + #optimizer = torch.optim.AdamW( + # model.parameters(), + # lr=training_config.lr, + # weight_decay=training_config.weight_decay, + #) + optimizer = torch.optim.SGD( model.parameters(), lr=training_config.lr, weight_decay=training_config.weight_decay, ) scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, - max_lr=0.0006, + max_lr=0.0005, steps_per_epoch=len(train_loader), epochs=training_config.epochs, ) @@ -78,17 +78,24 @@ def train(training_config, training_dataset, validation_dataset, token_to_idx, w train_losses = [] valid_losses = [] + + + # TODO: this is just pulling things out from a config + mode = Mode(training_config.mode) + loss_type = Loss(training_config.loss_type) + seq_type = SeqSubSet(training_config.seq_type) + for epoch in tqdm(range(training_config.epochs), desc="Epoch"): accelerator.print(f"Epoch {epoch}") wandb.log({"Epoch": epoch}) train_loss = train_model( - model, train_loader, optimizer, accelerator, scheduler, wandb + model, train_loader, optimizer, accelerator, scheduler, wandb, mode, loss_type, seq_type ) - valid_loss = validate_model(model, valid_loader, accelerator) + valid_loss, p1_acc, p2_acc, total_acc = validate_model(model, valid_loader, accelerator, mode, loss_type, seq_type) train_losses.append(train_loss) valid_losses.append(valid_loss) - scheduler.step() + #scheduler.step() # print("Learning Rate: ", scheduler.get_last_lr()) @@ -97,7 +104,7 @@ def train(training_config, training_dataset, validation_dataset, token_to_idx, w if accelerator.is_main_process: val_loss_str = f"Validation loss {valid_loss:.8f}" - wandb.log({"Validation Loss": valid_loss, "Training Loss": train_loss}) + wandb.log({"Validation Loss": valid_loss, "Training Loss": train_loss, "P1 Acc": p1_acc, "P2 Acc": p2_acc, "Total accuracy": total_acc}) accelerator.print(val_loss_str) model_save_path = f"model_{epoch+1}_mode_{mode}_seed_{seed}.pth" @@ -135,16 +142,23 @@ def mode_to_token_to_idx(mode): if mode == 0: token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} vocab_size = 43 + transformation = actions_to_col_row elif mode == 1: token_to_idx = {i: i + 1 for i in range(7)} vocab_size = 8 + transformation = lambda x: x elif mode == 2: token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} | { i: i + 44 for i in range(7) } vocab_size = 51 + transformation = lambda x: list(itertools.chain(*zip(x,actions_to_col_row(x)))) token_to_idx[""] = 0 # Padding token - return token_to_idx, vocab_size + + token_to_idx[51] = 51 + vocab_size += 1 + + return token_to_idx, vocab_size, transformation @hydra.main(version_base=None, config_path="../../conf", config_name="config") @@ -153,13 +167,13 @@ def main(cfg: DictConfig) -> None: training_config = cfg.training mode = training_config.mode - token_to_idx, vocab_size = mode_to_token_to_idx(mode) + token_to_idx, vocab_size, transformation = mode_to_token_to_idx(mode) # Make this a function with open(training_config.data_path, "r") as f: data = f.readlines() data = information_parser(data) - raw_dataset = [[action for (_, action) in x] for x in data] + raw_dataset = [transformation([action for (_, action) in x]) for x in data] training_dataset, validation_dataset, test_dataset = split_dataset( @@ -170,7 +184,7 @@ def main(cfg: DictConfig) -> None: training_config["vocab_size"] = vocab_size training_config["dataset_length"] = len(raw_dataset) - wandb.init(project=cfg.wandb.project_name, config=dict(training_config)) + wandb.init(project=cfg.wandb.project_name, config=dict(training_config), id=cfg.wandb.id) train(training_config, training_dataset, validation_dataset, token_to_idx, wandb) diff --git a/src/transformers_learn_mdp/data_utils.py b/src/transformers_learn_mdp/data_utils.py index 62b36e8..b2aa957 100644 --- a/src/transformers_learn_mdp/data_utils.py +++ b/src/transformers_learn_mdp/data_utils.py @@ -1,6 +1,31 @@ from typing import List +def actions_to_col_row(actions, board_height=6): + """ + Converts a sequence of Connect4 column moves into (column, row) pairs. + + Args: + actions (list): List of column indices (0-6) representing moves. + board_height (int): Number of rows in Connect4 (default: 6). + + Returns: + list of tuples: [(col, row), ...] where row is where the piece lands. + """ + heights = [0] * 7 # Track how filled each column is + col_row_sequence = [] + + for col in actions: + row = board_height - 1 - heights[col] # Compute the landing row + if row < 0: + raise ValueError(f"Invalid move: Column {col} is full!") + + col_row_sequence.append((row, col)) + heights[col] += 1 # Update column height + + return col_row_sequence + + def information_parser(info: List[str]): """ diff --git a/src/transformers_learn_mdp/dataset.py b/src/transformers_learn_mdp/dataset.py index 292ecef..c5965b0 100644 --- a/src/transformers_learn_mdp/dataset.py +++ b/src/transformers_learn_mdp/dataset.py @@ -4,30 +4,6 @@ from torch.nn.utils.rnn import pad_sequence import tqdm -def actions_to_col_row(actions, board_height=6): - """ - Converts a sequence of Connect4 column moves into (column, row) pairs. - - Args: - actions (list): List of column indices (0-6) representing moves. - board_height (int): Number of rows in Connect4 (default: 6). - - Returns: - list of tuples: [(col, row), ...] where row is where the piece lands. - """ - heights = [0] * 7 # Track how filled each column is - col_row_sequence = [] - - for col in actions: - row = board_height - 1 - heights[col] # Compute the landing row - if row < 0: - raise ValueError(f"Invalid move: Column {col} is full!") - - col_row_sequence.append((row, col)) - heights[col] += 1 # Update column height - - return col_row_sequence - class EpisodeDataset(Dataset): @@ -36,7 +12,7 @@ def __init__(self, data, token_to_idx, packing_length=30,padding_value=0): print("Tokenizing and packing the dataset") self.packed_data = [] - self.tokenized_data = [[self.token_to_idx[token] for token in sequence] for sequence in data] + self.tokenized_data = [[51] + [self.token_to_idx[token] for token in sequence] + [51] for sequence in data] # flatten the list and insert padding value at the end of each sequence #self.data = [] #for sequence in self.tokenized_data: @@ -51,7 +27,7 @@ def __len__(self): def __getitem__(self, idx): - sequence = self.data[idx] + sequence = self.data[idx] X_indices = sequence[:-1] Y_indices = sequence[1:] diff --git a/src/transformers_learn_mdp/trainer.py b/src/transformers_learn_mdp/trainer.py index 6d1add1..b17a8a6 100644 --- a/src/transformers_learn_mdp/trainer.py +++ b/src/transformers_learn_mdp/trainer.py @@ -1,111 +1,257 @@ import torch import torch.nn as nn +import itertools from tqdm import tqdm +import torch.nn.functional as F +from enum import Enum -def validate_model(model, valid_loader, accelerator): + +class Loss(Enum): + CrossEntropy = 0 + KLDivergence = 1 + + +class Mode(Enum): + STATE = 0 + ACTION = 1 + STATE_ACTION = 2 + + +# NOTE: as in whether to include just player 1 or the whole sequence +class SeqSubSet(Enum): + WHOLE = 0 + PLAYER_1 = 1 + + +def batch_one_hot(batch, number_of_classes): + """ + One-hot encode a batch of sequences. + """ + batch_size, seq_length = batch.size() + one_hot = torch.zeros(batch_size, seq_length, number_of_classes).to(batch.device) + one_hot.scatter_(2, batch.unsqueeze(-1), 1) + return one_hot + + +def loss_calc(loss_type, loss_fn, logits, target, indices, padding_token=0): + + assert loss_fn.reduction == "none" + + vocab_length = logits.shape[-1] + + target = ( + target[:, indices].contiguous().view(-1) + ) # expect this to be one dimensional + logits = logits[:, indices, :].contiguous().view(-1, logits.shape[-1]) + + if loss_type == Loss.KLDivergence: + """ + If loss is KL Divergence, the target needs to be a probability distribution over the + vocabulary + """ + logits = F.log_softmax(logits, dim=-1) + target = F.one_hot( + target, num_classes=vocab_length + ) # expect batch size is (batch_size x seq_length, vocab_length) + + loss = loss_fn(logits, target) + + if loss_type == Loss.KLDivergence: + """ + KLDivergence without reduction shoots out a seq of dim (seq_length ,vocab) + """ + loss = loss.mean(dim=-1) + + assert len(loss.shape) == 1 + + # ----- mask making ----------- + + mask = (target != padding_token).float() + + return (loss * mask, mask) + + +def logit_selection(length, mode, seq_type): + """ + + If mode is 0, select all the odd indices from 1 to length-1, because player 2 is randomly selecting the column. + + For mode 2 it's action state + + (a_0,s_0,a_1,s_1,a_0,s_1 ..... -> (a_0,s_0 ....), (s_0, a_1, ....) + + + """ + + whole_seq = range(length) + + if seq_type == SeqSubSet.WHOLE: + return list(whole_seq), list(whole_seq) + + if mode == 0 or mode == 1: + player_1 = list(range(0, length, 2)) + player_2 = [x for x in whole_seq if x not in player_1] + return (player_1, player_2) + + player_1 = [0, 1] # Why is 0 here? because action can be used to predict state + for i in range(3, length, 4): + player_1.extend(list(range(i, min(i + 3, length)))) + if player_1[-1] != length -1: + player_1.append(length-1) + + return (player_1, [x for x in whole_seq if x not in player_1]) + + +def criterion_f(loss_type): + if loss_type == Loss.CrossEntropy: + return nn.CrossEntropyLoss( + ignore_index=0, reduction="none", label_smoothing=0.0 + ) + else: + return nn.KLDivLoss(reduction="none") + + +# TODO: mode +def validate_model(model, valid_loader, accelerator, mode, loss_type, seq_type): + model.eval() - criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none') + criterion = criterion_f(loss_type) valid_loss = torch.tensor(0.0).to(accelerator.device) valid_data = torch.tensor(0.0).to(accelerator.device) + player_1_accuracy = torch.tensor(0.0).to(accelerator.device) + player_2_accuracy = torch.tensor(0.0).to(accelerator.device) + total_accuracy = torch.tensor(0.0).to(accelerator.device) + player_1_total = torch.tensor(0.0).to(accelerator.device) + player_2_total = torch.tensor(0.0).to(accelerator.device) with torch.no_grad(): for X_batch, Y_batch in valid_loader: - logits = model(X_batch)[:,1::2,:].contiguous() # Shape: [batch_size, seq_length, vocab_size] - logits = logits.view(-1, logits.size(-1)) # Shape: [batch_size * seq_length, vocab_size] - Y_batch = Y_batch[:,1::2].contiguous().view(-1) # Shape: [batch_size * seq_length] + p1_indices, p2_indices = logit_selection(X_batch.size(1), mode, seq_type) + + logits = model(X_batch) # Shape: [batch_size, seq_length, vocab_size] - # Assuming the padding token index is 0 - padding_token_index = 0 - mask = (Y_batch != padding_token_index).float() # Create a mask for valid positions + logits = F.log_softmax(logits, dim=-1) - loss = criterion(logits, Y_batch) # Calculate loss without reduction - masked_loss = loss * mask # Apply mask + p1_indices_, p2_indices_ = logit_selection(X_batch.size(1), mode, SeqSubSet.PLAYER_1) - valid_loss += masked_loss.sum().item() # Sum the losses at valid positions - valid_data += mask.sum().item() # Count valid positions + player_1_accuracy_ = ( + logits[:, p1_indices_].argmax(dim=-1) == Y_batch[:, p1_indices_] + ).float() + player_2_accuracy_ = ( + logits[:, p2_indices_].argmax(dim=-1) == Y_batch[:, p2_indices_] + ).float() + total_accuracy_ = (logits.argmax(dim=-1) == Y_batch).float() + + masked_loss, mask = loss_calc( + loss_type, criterion, logits, Y_batch, p1_indices + ) + + valid_loss += masked_loss.sum() # Sum the losses at valid positions + valid_data += mask.sum() # Count valid positions + player_1_accuracy += player_1_accuracy_.sum() + player_2_accuracy += player_2_accuracy_.sum() + total_accuracy += total_accuracy_.sum() + player_1_total += player_1_accuracy_.numel() + player_2_total += player_2_accuracy_.numel() accelerator.wait_for_everyone() - + valid_loss = accelerator.gather(valid_loss).sum() valid_data = accelerator.gather(valid_data).sum() + player_1_accuracy = accelerator.gather(player_1_accuracy).sum() + player_2_accuracy = accelerator.gather(player_2_accuracy).sum() + total_accuracy = accelerator.gather(total_accuracy).sum() + player_1_total = accelerator.gather(player_1_total).sum() + player_2_total = accelerator.gather(player_2_total).sum() if accelerator.is_main_process: - return (valid_loss / valid_data).item() + return ( + (valid_loss / valid_data).item(), + (player_1_accuracy / player_1_total).item(), + (player_2_accuracy / player_2_total).item(), + (total_accuracy / (player_1_total + player_2_total)).item(), + ) else: return None -def train_model(model, train_loader, optimizer, accelerator, scheduler, wandb): + +def train_model( + model, + train_loader, + optimizer, + accelerator, + scheduler, + wandb, + mode, + loss_type, + seq_type, +): model.train() - criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none', label_smoothing=0.0) + + criterion = criterion_f(loss_type) train_loss = torch.tensor(0.0).to(accelerator.device) train_data = torch.tensor(0.0).to(accelerator.device) for X_batch, Y_batch in tqdm(train_loader, desc="Training"): - + optimizer.zero_grad() - logits = model(X_batch)[:,1::2,:].contiguous() - logits = logits.view(-1, logits.size(-1)) # Shape: [batch_size * seq_length, vocab_size] - Y_batch = Y_batch[:,1::2].contiguous().view(-1) # Shape: [batch_size * seq_length] + p1_indices, p2_indices = logit_selection(X_batch.size(1), mode, seq_type) + + logits = model(X_batch) # Shape: [batch_size, seq_length, vocab_size] + + masked_loss, mask = loss_calc(None, criterion, logits, Y_batch, p1_indices) - padding_token_index = 0 # Assuming the padding token index is 0 - mask = (Y_batch != padding_token_index).float() - - loss = criterion(logits, Y_batch) - masked_loss = loss * mask - loss_sum = masked_loss.sum() data_sum = mask.sum() loss = loss_sum / data_sum accelerator.backward(loss) + # nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() if scheduler is not None: scheduler.step() wandb.log({"Learning Rate": scheduler.get_last_lr()[0]}) - train_loss += loss_sum.item() + train_loss += loss_sum.item() train_data += mask.sum().item() - - + grad_norms = [] - + for param in model.parameters(): if param.grad is not None: grad_norms.append(param.grad.norm().item()) if len(grad_norms) == 0: return None # No gradients yet (e.g., before first backward pass) - + grad_tensor = torch.tensor(grad_norms) - + stats = { "grad_mean": grad_tensor.mean().item(), "grad_max": grad_tensor.max().item(), "grad_std": grad_tensor.std().item(), - "grad_p95": grad_tensor.quantile(0.95).item() + "grad_p95": grad_tensor.quantile(0.95).item(), } - #accelerator.print('Gradient norm:', stats) + # accelerator.print('Gradient norm:', stats) wandb.log(stats) accelerator.wait_for_everyone() - + train_loss = accelerator.gather(train_loss).sum() train_data = accelerator.gather(train_data).sum() - - accelerator.print('Training Loss:', (train_loss / train_data).item()) + + accelerator.print("Training Loss:", (train_loss / train_data).item()) if accelerator.is_main_process: return (train_loss / train_data).item() else: return None - - diff --git a/transformer_training_mcts/connect4_train_mcts.py b/transformer_training_mcts/connect4_train_mcts.py deleted file mode 100644 index ef40eed..0000000 --- a/transformer_training_mcts/connect4_train_mcts.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import sys -import pickle -import shutil -import torch -import argparse -from tqdm import tqdm - -sys.path.append('../') - -from accelerate import Accelerator -from dataset import EpisodeDataset, collate_fn -from model import Config, GPTModel -from trainer import train_model, validate_model -from torch.utils.data import DataLoader - -""" -Training pipeline for transformer on Connect-4 data generated through MCTS. -""" - -def train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, mode, seed, save_directory = None, epochs = 15): - - accelerator = Accelerator() - - train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn) - valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn) - - config = Config(vocab_size, block_size, n_layer=num_layers, n_head=num_layers, n_embd=embed_size) - model = GPTModel(config) - - optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs) - - train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare(train_loader, valid_loader, model, scheduler, optimizer) - - epoch = 0 - - model_path = None - min_loss = 1e10 - - train_losses = [] - valid_losses = [] - - for epoch in tqdm(range(epochs)): - accelerator.print(f'Epoch {epoch}') - - train_loss = train_model(model, train_loader, optimizer, accelerator) - valid_loss = validate_model(model, valid_loader, accelerator) - train_losses.append(train_loss) - valid_losses.append(valid_loss) - scheduler.step() - - if accelerator.is_main_process: - print(f'Validation Loss: {valid_loss:.8f}') - - model_save_path = f"model_{epoch+1}_mode_{mode}_seed_{seed}.pth" - accelerator.save(accelerator.unwrap_model(model).state_dict(), model_save_path) - - if valid_loss < min_loss: - min_loss = valid_loss - model_path = model_save_path - - accelerator.wait_for_everyone() - - if accelerator.is_main_process: - shutil.copy(model_path, save_directory) - - with open(f'train_losses_mode_{mode}_seed_{seed}.pkl', 'wb') as f: - pickle.dump(train_losses, f) - with open(f'valid_losses_mode_{mode}_seed_{seed}.pkl', 'wb') as f: - pickle.dump(valid_losses, f) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('-m', type=int, default=0, choices=[0, 1, 2], help='Data Mode (state, action, state-action)') - parser.add_argument('-s', type=int, default=0, choices=[0, 1, 2], help='Seed') - args = parser.parse_args() - if args.m == 0: - token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} - vocab_size = 43 - elif args.m == 1: - token_to_idx = {i: i + 1 for i in range(7)} - vocab_size = 8 - elif args.m == 2: - token_to_idx = {(i, j): i * 7 + j + 1 for i in range(6) for j in range(7)} | {i: i + 44 for i in range(7)} - vocab_size = 51 - token_to_idx[''] = 0 # Padding token - block_size = 42 - embed_size = 512 - num_layers = 8 - - path = '' - - with open(os.path.join(path, rf'training_data/mcts/training_games_mode_{args.m}.pkl'), 'rb') as f: - agent1 = pickle.load(f) - - train_ratio = 0.8 - valid_ratio = 0.1 - - d1 = len(agent1) - - train = agent1[:int(train_ratio * d1)] - valid = agent1[int(train_ratio * d1):int((train_ratio + valid_ratio) * d1) ] - test = agent1[int((train_ratio + valid_ratio) * d1): ] - - print(len(train)) - print(len(valid)) - print(len(test)) - - train_dataset = EpisodeDataset(train, token_to_idx) - valid_dataset = EpisodeDataset(valid, token_to_idx) - - train_main(train_dataset, valid_dataset, vocab_size, block_size, num_layers, embed_size, args.m, args.s, "best_model") - -if __name__ == "__main__": - main() - diff --git a/transformer_training_mcts/dataset.py b/transformer_training_mcts/dataset.py deleted file mode 100644 index ff5ff4a..0000000 --- a/transformer_training_mcts/dataset.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -from torch.utils.data import Dataset -from torch.nn.utils.rnn import pad_sequence - -class EpisodeDataset(Dataset): - - def __init__(self, data, token_to_idx): - self.data = data - self.token_to_idx = token_to_idx - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - - X_sequence, Y_sequence = self.data[idx] - - X_indices = [self.token_to_idx[token] for token in X_sequence] - Y_indices = [self.token_to_idx[token] for token in Y_sequence] - - return torch.tensor(X_indices, dtype=torch.long), torch.tensor(Y_indices, dtype=torch.long) - -def collate_fn(batch): - - Xs, Ys = zip(*batch) - - Xs_padded = pad_sequence(Xs, batch_first=True, padding_value=0) - Ys_padded = pad_sequence(Ys, batch_first=True, padding_value=0) - - return Xs_padded, Ys_padded \ No newline at end of file diff --git a/transformer_training_mcts/model.py b/transformer_training_mcts/model.py deleted file mode 100644 index 33ab6d0..0000000 --- a/transformer_training_mcts/model.py +++ /dev/null @@ -1,139 +0,0 @@ -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -class Config: - """ - Configuration for the GPT model including size parameters and dropout rates. - """ - def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer): - self.vocab_size = vocab_size # Vocabulary size - self.block_size = block_size # Input sequence length - self.n_embd = n_embd # Embedding dimension - self.n_head = n_head # Number of attention heads - self.n_layer = n_layer # Number of transformer layers - # Dropout rates - self.embd_pdrop = 0.1 - self.resid_pdrop = 0.1 - self.attn_pdrop = 0.1 - -class CausalSelfAttention(nn.Module): - """ - Causal self-attention module implementing scaled dot-product attention. - """ - def __init__(self, config): - super().__init__() - self.config = config - assert self.config.n_embd % self.config.n_head == 0, "embedding dimension must be divisible by the number of heads." - - # Key, query, and value linear transformations - self.key = nn.Linear(self.config.n_embd, self.config.n_embd) - self.query = nn.Linear(self.config.n_embd, self.config.n_embd) - self.value = nn.Linear(self.config.n_embd, self.config.n_embd) - - # Dropout layers - self.attn_drop = nn.Dropout(self.config.attn_pdrop) - self.resid_drop = nn.Dropout(self.config.resid_pdrop) - - # Output projection - self.proj = nn.Linear(self.config.n_embd, self.config.n_embd) - - # Causal mask to prevent attention to future tokens - self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)).unsqueeze(0).unsqueeze(1)) - - def forward(self, x): - B, T, C = x.size() # Batch size, sequence length, embedding dimension - - # Calculate projections and reshape for multi-headed attention - k = self.key(x).view(B, T, self.config.n_head, C // self.config.n_head).transpose(1, 2) - q = self.query(x).view(B, T, self.config.n_head, C // self.config.n_head).transpose(1, 2) - v = self.value(x).view(B, T, self.config.n_head, C // self.config.n_head).transpose(1, 2) - - # Compute the attention scores - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) - att = F.softmax(att, dim=-1) - att = self.attn_drop(att) @ v - - # Re-assemble all head outputs side by side - y = att.transpose(1, 2).contiguous().view(B, T, C) - - # Output projection - y = self.resid_drop(self.proj(y)) - return y - -class TransformerBlock(nn.Module): - """ - A single transformer block containing a causal self-attention layer and a feed-forward network. - """ - def __init__(self, config): - super().__init__() - self.attn = CausalSelfAttention(config) - self.ln1 = nn.LayerNorm(config.n_embd) - self.mlp = nn.Sequential( - nn.Linear(config.n_embd, 4 * config.n_embd), - nn.GELU(), - nn.Linear(4 * config.n_embd, config.n_embd), - nn.Dropout(config.resid_pdrop), - ) - self.ln2 = nn.LayerNorm(config.n_embd) - - def forward(self, x): - x = x + self.attn(self.ln1(x)) - x = x + self.mlp(self.ln2(x)) - return x - -class GPTModel(nn.Module): - """ - The GPT model comprising an embedding layer, multiple transformer blocks, and a final output layer. - """ - def __init__(self, config): - super().__init__() - self.config = config - - # Embedding layers - self.tok_emb = nn.Embedding(self.config.vocab_size, self.config.n_embd) - self.pos_emb = nn.Parameter(torch.zeros(1, self.config.block_size, self.config.n_embd)) - self.drop = nn.Dropout(self.config.embd_pdrop) - - # Transformer blocks - self.blocks = nn.Sequential(*[TransformerBlock(config) for _ in range(self.config.n_layer)]) - - # Final layer normalization and linear output layer - self.ln_f = nn.LayerNorm(self.config.n_embd) - self.head = nn.Linear(self.config.n_embd, self.config.vocab_size, bias=False) - - self.apply(self._init_weights) - - def _init_weights(self, module): - # Initialize weights and biases - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def forward(self, idx, num_probe=None): - # Processing input - _, T = idx.size() - assert T <= self.config.block_size, "Input sequence too long." - - token_embeddings = self.tok_emb(idx) - position_embeddings = self.pos_emb[:, :T, :] - x = self.drop(token_embeddings + position_embeddings) - - # Retrieve Embedding if num_probe is specified - if num_probe is not None: - for block in self.blocks[:num_probe]: - x = block(x) - return x - - # Process through all blocks, then project to vocabulary size - x = self.blocks(x) - x = self.ln_f(x) - logits = self.head(x) - - return logits diff --git a/transformer_training_mcts/trainer.py b/transformer_training_mcts/trainer.py deleted file mode 100644 index 28144f5..0000000 --- a/transformer_training_mcts/trainer.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.nn as nn - -from tqdm import tqdm - -def validate_model(model, valid_loader, accelerator): - - model.eval() - criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none') - - valid_loss = torch.tensor(0.0).to(accelerator.device) - valid_data = torch.tensor(0.0).to(accelerator.device) - - with torch.no_grad(): - - for X_batch, Y_batch in valid_loader: - - logits = model(X_batch) - logits = logits.view(-1, logits.size(-1)) # Shape: [batch_size * seq_length, vocab_size] - Y_batch = Y_batch.view(-1) # Shape: [batch_size * seq_length] - - # Assuming the padding token index is 0 - padding_token_index = 0 - mask = (Y_batch != padding_token_index).float() # Create a mask for valid positions - - loss = criterion(logits, Y_batch) # Calculate loss without reduction - masked_loss = loss * mask # Apply mask - - valid_loss += masked_loss.sum().item() # Sum the losses at valid positions - valid_data += mask.sum().item() # Count valid positions - - accelerator.wait_for_everyone() - - valid_loss = accelerator.gather(valid_loss).sum() - valid_data = accelerator.gather(valid_data).sum() - - if accelerator.is_main_process: - return (valid_loss / valid_data).item() - else: - return None - -def train_model(model, train_loader, optimizer, accelerator): - - model.train() - criterion = nn.CrossEntropyLoss(ignore_index=0, reduction = 'none') - - train_loss = torch.tensor(0.0).to(accelerator.device) - train_data = torch.tensor(0.0).to(accelerator.device) - - for X_batch, Y_batch in tqdm(train_loader, desc="Training"): - - optimizer.zero_grad() - logits = model(X_batch) - - logits = logits.view(-1, logits.size(-1)) # Shape: [batch_size * seq_length, vocab_size] - Y_batch = Y_batch.view(-1) # Shape: [batch_size * seq_length] - - padding_token_index = 0 # Assuming the padding token index is 0 - mask = (Y_batch != padding_token_index).float() - - loss = criterion(logits, Y_batch) - masked_loss = loss * mask - - loss_sum = masked_loss.sum() - data_sum = mask.sum() - - loss = loss_sum / data_sum - accelerator.backward(loss) - optimizer.step() - - train_loss += loss_sum.item() - train_data += mask.sum().item() - - accelerator.wait_for_everyone() - - train_loss = accelerator.gather(train_loss).sum() - train_data = accelerator.gather(train_data).sum() - - accelerator.print('Training Loss:', (train_loss / train_data).item()) - - return train_loss - - From db3f2d97d10ad7a61855dd923d8e80855314f884 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 5 Feb 2025 02:59:03 -0500 Subject: [PATCH 6/6] Trial changes --- conf/config.yaml | 12 ++--- .../connect4_train_mcts.py | 53 ++++++++++++++----- src/transformers_learn_mdp/trainer.py | 2 +- 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/conf/config.yaml b/conf/config.yaml index 84d8896..0a55f5a 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -1,6 +1,6 @@ wandb: - project_name: "board-representation-experiments" - id: "test-run-delete-soon-1" + project_name: "board-representation-experiments-5th-feb" + id: "cross-entropy-gpt-warmup-scheduler-whole-seq-a" training: mode: 2 @@ -11,11 +11,11 @@ training: train_ratio: 0.8 val_ratio: 0.1 data_path: "info.txt" - epochs: 5 - lr: 0.0001 - weight_decay: 0.01 + epochs: 15 + lr: 0.00001 + weight_decay: 0.09 save_directory: "./save" embedding_size: 128 seed: 1243 loss_type: 0 - seq_type: 1 + seq_type: 0 diff --git a/src/transformers_learn_mdp/connect4_train_mcts.py b/src/transformers_learn_mdp/connect4_train_mcts.py index 7ae53d8..47d45f1 100644 --- a/src/transformers_learn_mdp/connect4_train_mcts.py +++ b/src/transformers_learn_mdp/connect4_train_mcts.py @@ -18,6 +18,28 @@ from .data_utils import information_parser, actions_to_col_row from enum import Enum +def get_lr_scheduler(optimizer, warmup_epochs, total_epochs, base_lr, max_lr): + """ + Combines warmup and cosine annealing for learning rate scheduling. + + Args: + optimizer: PyTorch optimizer + warmup_epochs: Number of warmup epochs + total_epochs: Total number of training epochs + base_lr: Starting learning rate (during warmup) + max_lr: Peak learning rate (after warmup) + + Returns: + scheduler: Learning rate scheduler + """ + def lr_lambda(epoch): + if epoch < warmup_epochs: + return 2*epoch # Linear warmup + else: + return 10*epoch + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + def train(training_config, training_dataset, validation_dataset, token_to_idx, wandb): @@ -48,24 +70,26 @@ def train(training_config, training_dataset, validation_dataset, token_to_idx, w ) model = GPTModel(config) - #optimizer = torch.optim.AdamW( - # model.parameters(), - # lr=training_config.lr, - # weight_decay=training_config.weight_decay, - #) - optimizer = torch.optim.SGD( + optimizer = torch.optim.AdamW( model.parameters(), lr=training_config.lr, weight_decay=training_config.weight_decay, ) - scheduler = torch.optim.lr_scheduler.OneCycleLR( - optimizer, - max_lr=0.0005, - steps_per_epoch=len(train_loader), - epochs=training_config.epochs, - ) + #optimizer = torch.optim.SGD( + # model.parameters(), + # lr=training_config.lr, + # weight_decay=training_config.weight_decay, + #) + #scheduler = torch.optim.lr_scheduler.OneCycleLR( + # optimizer, + # max_lr=0.0005, + # steps_per_epoch=len(train_loader), + # epochs=training_config.epochs, + #) # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs) + scheduler = get_lr_scheduler(optimizer,5, training_config.epochs, 0.00001, 0.001) + train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare( train_loader, valid_loader, model, scheduler, optimizer ) @@ -90,12 +114,13 @@ def train(training_config, training_dataset, validation_dataset, token_to_idx, w wandb.log({"Epoch": epoch}) train_loss = train_model( - model, train_loader, optimizer, accelerator, scheduler, wandb, mode, loss_type, seq_type + model, train_loader, optimizer, accelerator, None, wandb, mode, loss_type, seq_type ) valid_loss, p1_acc, p2_acc, total_acc = validate_model(model, valid_loader, accelerator, mode, loss_type, seq_type) train_losses.append(train_loss) valid_losses.append(valid_loss) - #scheduler.step() + scheduler.step() + accelerator.print({"Learning Rate": scheduler.get_last_lr()[0]}) # print("Learning Rate: ", scheduler.get_last_lr()) diff --git a/src/transformers_learn_mdp/trainer.py b/src/transformers_learn_mdp/trainer.py index b17a8a6..cadba8b 100644 --- a/src/transformers_learn_mdp/trainer.py +++ b/src/transformers_learn_mdp/trainer.py @@ -214,7 +214,7 @@ def train_model( loss = loss_sum / data_sum accelerator.backward(loss) - # nn.utils.clip_grad_norm_(model.parameters(), 1.0) + nn.utils.clip_grad_norm_(model.parameters(), 5.0) optimizer.step() if scheduler is not None: scheduler.step()