From 387535b23783c0167bc31eb70476805a2c868f5b Mon Sep 17 00:00:00 2001 From: SubhanshuSethi <134162561+Subhanshusethi@users.noreply.github.com> Date: Tue, 26 Nov 2024 03:43:56 +0530 Subject: [PATCH] added debugging to train.py 1. Added a valid indices check which, while loading the tokens file, ensures that the number of captions matches the number of embeddings. If not, the mismatched entries are filtered out. 2. Fine-tuning GPT-2 is a major task that highly relies on the dataset. Removed single letters, special characters, and stop words (as defined by NLTK by default) to reduce the impact of connector words while training in the embedding space. --- train.py | 106 ++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 73 insertions(+), 33 deletions(-) diff --git a/train.py b/train.py index 61e945e..029b0f9 100644 --- a/train.py +++ b/train.py @@ -11,15 +11,84 @@ import argparse import json from typing import Tuple, Optional, Union - +import re +from nltk.corpus import stopwords class MappingType(Enum): MLP = 'mlp' Transformer = 'transformer' + class ClipCocoDataset(Dataset): + def __init__(self, data_path: str, prefix_length: int, gpt2_type: str = "gpt2", + normalize_prefix=False): + self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type) + self.prefix_length = prefix_length + self.normalize_prefix = normalize_prefix + self.stop_words = set(stopwords.words('english')) + # Load data + with open(data_path, 'rb') as f: + all_data = pickle.load(f) + print("Data size is %0d" % len(all_data["clip_embedding"])) + sys.stdout.flush() + + self.prefixes = all_data["clip_embedding"] + captions_raw = all_data["captions"] + self.image_ids = [caption["image_id"] for caption in captions_raw] + self.captions = [caption['caption'] for caption in captions_raw] + + # Tokenized captions file path + tokens_file_path = f"/kaggle/working/{os.path.basename(data_path).split('.')[0]}_{gpt2_type}_tokens.pkl" + + if os.path.isfile(tokens_file_path): + print("Loading tokenized captions from pickle file...") + with open(tokens_file_path, 'rb') as f: + self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f) + else: + print("Tokenizing captions and saving to pickle file...") + self.captions_tokens = [] + self.caption2embedding = [] + max_seq_len = 0 + + for caption in captions_raw: + processed_caption = self.preprocess_caption(caption['caption']) + tokens = torch.tensor(self.tokenizer.encode(processed_caption), dtype=torch.int64) + self.captions_tokens.append(tokens) + self.caption2embedding.append(caption["clip_embedding"]) # Should be an index + max_seq_len = max(max_seq_len, tokens.shape[0]) + + # Save tokenized captions to pickle file + with open(tokens_file_path, 'wb') as f: + pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f) + # Validate indices + valid_indices = [i for i, idx in enumerate(self.caption2embedding) if idx < len(self.prefixes)] + if len(valid_indices) < len(self.caption2embedding): + print(f"Found {len(self.caption2embedding) - len(valid_indices)} invalid indices. Filtering out invalid captions.") + self.captions_tokens = [self.captions_tokens[i] for i in valid_indices] + self.caption2embedding = [self.caption2embedding[i] for i in valid_indices] + self.image_ids = [self.image_ids[i] for i in valid_indices] + self.captions = [self.captions[i] for i in valid_indices] + + # Compute max sequence length based on tokenized data + all_len = torch.tensor([len(tokens) for tokens in self.captions_tokens]).float() + self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max())) + def preprocess_caption(self, caption: str) -> str: + """ + Preprocesses the caption by normalizing case, removing special characters, + redundant white spaces, and stopwords. + """ + # Convert to lowercase + caption = caption.lower() + # Remove special characters + caption = re.sub(r"[^\w\s]", "", caption) # Retain only letters, digits, and spaces + # Remove redundant white spaces + caption = re.sub(r"\s+", " ", caption).strip() + # Remove stopwords + words = caption.split() + caption = " ".join([word for word in words if word not in self.stop_words]) + return caption def __len__(self) -> int: return len(self.captions_tokens) @@ -32,49 +101,20 @@ def pad_tokens(self, item: int): elif padding < 0: tokens = tokens[:self.max_seq_len] self.captions_tokens[item] = tokens - mask = tokens.ge(0) # mask is zero where we out of sequence + mask = tokens.ge(0) # Mask is zero where we are out of sequence tokens[~mask] = 0 mask = mask.float() - mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0) # adding prefix mask + mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0) # Adding prefix mask return tokens, mask def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]: tokens, mask = self.pad_tokens(item) - prefix = self.prefixes[self.caption2embedding[item]] + prefix = self.prefixes[self.caption2embedding[item]] # Use index to get embedding if self.normalize_prefix: prefix = prefix.float() prefix = prefix / prefix.norm(2, -1) return tokens, mask, prefix - def __init__(self, data_path: str, prefix_length: int, gpt2_type: str = "gpt2", - normalize_prefix=False): - self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type) - self.prefix_length = prefix_length - self.normalize_prefix = normalize_prefix - with open(data_path, 'rb') as f: - all_data = pickle.load(f) - print("Data size is %0d" % len(all_data["clip_embedding"])) - sys.stdout.flush() - self.prefixes = all_data["clip_embedding"] - captions_raw = all_data["captions"] - self.image_ids = [caption["image_id"] for caption in captions_raw] - self.captions = [caption['caption'] for caption in captions_raw] - if os.path.isfile(f"{data_path[:-4]}_tokens.pkl"): - with open(f"{data_path[:-4]}_tokens.pkl", 'rb') as f: - self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f) - else: - self.captions_tokens = [] - self.caption2embedding = [] - max_seq_len = 0 - for caption in captions_raw: - self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption['caption']), dtype=torch.int64)) - self.caption2embedding.append(caption["clip_embedding"]) - max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0]) - # self.max_seq_len = max_seq_len - with open(f"{data_path[:-4]}_tokens.pkl", 'wb') as f: - pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f) - all_len = torch.tensor([len(self.captions_tokens[i]) for i in range(len(self))]).float() - self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max())) class MLP(nn.Module):