From 7da0cef93aeeaa6d0fe8ad0e8f5e3eab99bf0cc6 Mon Sep 17 00:00:00 2001
From: Devansh Amin
Date: Fri, 21 Jun 2024 21:15:31 -0400
Subject: [PATCH 1/3] Rename split stage and add proper type hint

---
 tag_llm/data/parser/base.py         | 3 ++-
 tag_llm/data/parser/pubmed/graph.py | 8 ++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/tag_llm/data/parser/base.py b/tag_llm/data/parser/base.py
index 3573029..1a7d34c 100644
--- a/tag_llm/data/parser/base.py
+++ b/tag_llm/data/parser/base.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 from typing import Dict, List, Optional
 
+import torch
 from torch_geometric.data import Data
 
 
@@ -57,7 +58,7 @@ def __init__(self, articles_file_path: Path, cache_dir: Path) -> None:
         self.class_labels: Optional[List[ClassLabel]] = None
         self.articles: Optional[List[Article]] = None
         # Split containing train/val/test node ids
-        self.split: Optional[Dict] = None
+        self.split: Optional[Dict[str, torch.Tensor]] = None
 
     @abstractmethod
     def load(self) -> None:
diff --git a/tag_llm/data/parser/pubmed/graph.py b/tag_llm/data/parser/pubmed/graph.py
index a10b3b4..635baaa 100644
--- a/tag_llm/data/parser/pubmed/graph.py
+++ b/tag_llm/data/parser/pubmed/graph.py
@@ -134,14 +134,14 @@ def load_dataset(self) -> None:
         self.dataset.y = self._node_labels
         self.dataset.edge_index = self._edge_index
 
-        # Split dataset nodes into train/val/test and update the train/val/test masks
+        # Split dataset nodes into train/valid/test and update the train/valid/test masks
         n_nodes = self.dataset.num_nodes
         node_ids = torch.randperm(n_nodes)
         self.split = {}
-        for split_name in ('train', 'val', 'test'):
+        for split_name in ('train', 'valid', 'test'):
             if split_name == 'train':
                 subset = slice(0, int(n_nodes * 0.6))
-            elif split_name == 'val':
+            elif split_name == 'valid':
                 subset = slice(int(n_nodes * 0.6), int(n_nodes * 0.8))
             else:
                 subset = slice(int(n_nodes * 0.8), n_nodes)
@@ -151,4 +151,4 @@ def load_dataset(self) -> None:
             mask = torch.zeros(n_nodes, dtype=bool)
             mask[ids] = True
             setattr(self.dataset, f'{split_name}_mask', mask)
-            self.split[split_name] = ids.tolist()
+            self.split[split_name] = ids

From 59dced9b927cda59316d43df66881c84be5ec89b Mon Sep 17 00:00:00 2001
From: Devansh Amin
Date: Fri, 21 Jun 2024 21:16:31 -0400
Subject: [PATCH 2/3] Add property `split_idx`

---
 tag_llm/data/dataset.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/tag_llm/data/dataset.py b/tag_llm/data/dataset.py
index 30be9a3..4fc0f76 100644
--- a/tag_llm/data/dataset.py
+++ b/tag_llm/data/dataset.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 import torch
 from torch_geometric.data import Data
@@ -38,7 +38,8 @@ def __init__(
         self.lm_encoder = LmEncoder(args=lm_encoder_args)
 
         self._parser: Optional[Parser] = None
-        self._topk = None
+        self._split_idx: Optional[Dict[str, torch.Tensor]] = None
+        self._topk: Optional[int] = None
 
     @property
     def dataset(self) -> Data:
@@ -51,6 +52,13 @@ def dataset(self) -> Data:
     def num_classes(self) -> int:
         return self._parser.graph.n_classes
 
+    @property
+    def split_idx(self) -> Dict[str, torch.Tensor]:
+        if self._split_idx is None:
+            _ = self.dataset
+            self._split_idx = self._parser.graph.split
+        return self._split_idx
+
     @property
     def topk(self) -> int:
         """TopK ranked LLM predictions."""

From 735c7512c9472dafbf5c8bf940e44f0df7fc7992 Mon Sep 17 00:00:00 2001
From: Devansh Amin
Date: Fri, 21 Jun 2024 21:18:31 -0400
Subject: [PATCH 3/3] Add support for mini-batch training

---
Review notes (text between `---` and the diffstat is not part of the commit):
- `_train_eval_full_batch` now asserts `mask is not None`; truth-testing a
  multi-element tensor raises RuntimeError, so the previous `assert mask`
  could never pass on a real graph.
- `num_workers=0` is now honoured; only an unset value falls back to
  `os.cpu_count()` (or 0 when the CPU count cannot be determined).
- Per-batch logits are detached before being accumulated so the autograd
  graph of each training batch is not kept alive for the whole epoch.
- Hunk line counts and the diffstat are unchanged by these fixes.

 tag_llm/trainer/gnn_trainer.py | 136 ++++++++++++++++++++++++++++-----
 1 file changed, 117 insertions(+), 19 deletions(-)

diff --git a/tag_llm/trainer/gnn_trainer.py b/tag_llm/trainer/gnn_trainer.py
index afa4274..fde5b89 100644
--- a/tag_llm/trainer/gnn_trainer.py
+++ b/tag_llm/trainer/gnn_trainer.py
@@ -1,8 +1,12 @@
+import os
 from dataclasses import dataclass
-from typing import Literal, Optional
+from enum import Enum
+from typing import Optional
 
 import torch
 from torch_geometric.data import Data
+from torch_geometric.loader import NeighborLoader
+from tqdm import tqdm
 
 from tag_llm.data.dataset import GraphDataset
 from tag_llm.gnn_model import NodeClassifier, NodeClassifierArgs
@@ -14,8 +18,17 @@ class GnnTrainerArgs:
     lr: float
     weight_decay: float = 0.0
     early_stopping_patience: int = 50
+    batch_size: Optional[int] = None  # Mini-batch training
+    num_neighbors: Optional[int] = None  # Mini-batch training
+    num_workers: Optional[int] = None  # Mini-batch training
     device: Optional[str] = None
 
+    def __post_init__(self) -> None:
+        self.is_mini_batch_training = self.batch_size is not None
+        if self.is_mini_batch_training and not self.num_neighbors:
+            print('`gnn_trainer.num_neighbors` was not provided. Using the default value of 10.')
+            self.num_neighbors = 10
+        self.num_workers = (os.cpu_count() or 0) if self.num_workers is None else self.num_workers
 
 @dataclass
 class GnnTrainerOutput:
@@ -23,13 +36,18 @@ class GnnTrainerOutput:
     accuracy: float
     logits: torch.Tensor
 
+class TrainingStage(str, Enum):
+    TRAIN = 'train'
+    VALID = 'valid'
+    TEST = 'test'
 
-class GnnTrainer:
 
+class GnnTrainer:
     def __init__(self, trainer_args: GnnTrainerArgs, graph_dataset: GraphDataset, model_args: NodeClassifierArgs) -> None:
-
         self.trainer_args = trainer_args
+        self.model_args = model_args
         self.dataset: Data = graph_dataset.dataset
+        self.split_idx = graph_dataset.split_idx
         self.device = trainer_args.device or ('cuda' if torch.cuda.is_available() else 'cpu')
 
         use_predictions = graph_dataset.feature_type == 'prediction'
@@ -50,17 +68,24 @@ def train(self) -> GnnTrainerOutput:
         patience = self.trainer_args.early_stopping_patience
         best_val_loss = float('inf')
         epochs_without_improvement = 0
+        dataloaders = None
 
         for epoch in range(1, self.trainer_args.epochs + 1):
-            train_output = self._train_eval(self.dataset, stage='train')
-            val_output = self._train_eval(self.dataset, stage='val')
+            if self.trainer_args.is_mini_batch_training:
+                if dataloaders is None:
+                    dataloaders = self._get_dataloaders()
+                train_output = self._train_eval_mini_batch(epoch, dataloaders['train'], TrainingStage.TRAIN)
+                valid_output = self._train_eval_mini_batch(epoch, dataloaders['valid'], TrainingStage.VALID)
+            else:
+                train_output = self._train_eval_full_batch(TrainingStage.TRAIN)
+                valid_output = self._train_eval_full_batch(TrainingStage.VALID)
             print(
                 f'Epoch: {epoch:03d} | Train loss: {train_output.loss:.4f}, '
-                f'Val loss: {val_output.loss:.4f}, Train accuracy: {train_output.accuracy:.4f}, '
-                f'Val accuracy: {val_output.accuracy:.4f}'
+                f'Valid loss: {valid_output.loss:.4f}, Train accuracy: {train_output.accuracy:.4f}, '
+                f'Valid accuracy: {valid_output.accuracy:.4f}'
             )
-            if val_output.loss < best_val_loss:
-                best_val_loss = val_output.loss
+            if valid_output.loss < best_val_loss:
+                best_val_loss = valid_output.loss
                 epochs_without_improvement = 0
             else:
                 epochs_without_improvement += 1
@@ -69,18 +94,56 @@ def train(self) -> GnnTrainerOutput:
                 print(f'Early stopping on epoch {epoch} due to no improvement in validation loss for {patience} epochs.')
                 break
 
-        output = self._train_eval(self.dataset, stage='test')
-        return output
-
-    def _train_eval(self, data: Data, stage: Literal['train', 'val', 'test']):
-        if stage == 'train':
+        if self.trainer_args.is_mini_batch_training:
+            test_output = self._train_eval_mini_batch(epoch, dataloaders['test'], TrainingStage.TEST)
+        else:
+            test_output = self._train_eval_full_batch(TrainingStage.TEST)
+        return test_output
+
+    def _get_dataloaders(self):
+        config = self.trainer_args
+        num_neighbors = [config.num_neighbors] * self.model_args.num_layers
+        persistent_workers = config.num_workers > 0
+        train_dataloader = NeighborLoader(
+            data=self.dataset,
+            num_neighbors=num_neighbors,
+            input_nodes=self.split_idx['train'],
+            batch_size=config.batch_size,
+            shuffle=True,
+            num_workers=config.num_workers,
+            persistent_workers=persistent_workers,
+        )
+        valid_dataloader = NeighborLoader(
+            data=self.dataset,
+            num_neighbors=num_neighbors,
+            input_nodes=self.split_idx['valid'],
+            batch_size=config.batch_size,
+            num_workers=config.num_workers,
+            persistent_workers=persistent_workers,
+        )
+        test_dataloader = NeighborLoader(
+            data=self.dataset,
+            num_neighbors=num_neighbors,
+            input_nodes=self.split_idx['test'],
+            batch_size=config.batch_size,
+            num_workers=config.num_workers,
+            persistent_workers=persistent_workers,
+        )
+        return dict(train=train_dataloader, valid=valid_dataloader, test=test_dataloader)
+
+    def _train_eval_full_batch(self, stage: TrainingStage):
+        if stage == TrainingStage.TRAIN:
             self.model.train()
         else:
             self.model.eval()
 
-        data = data.to(self.device)
-        mask = getattr(data, f'{stage}_mask')
-        if stage == 'train':
+        data = self.dataset.to(self.device)
+        mask = getattr(data, f'{stage.value}_mask', None)
+        assert mask is not None, (
+            'Missing `*_mask` attributes from the dataset! `train_mask`, '
+            '`valid_mask` and `test_mask` are required for full-batch training.'
+        )
+        if stage == TrainingStage.TRAIN:
             self.optimizer.zero_grad()
             logits = self.model(data.x, data.edge_index)
             loss = self.criterion(logits[mask], data.y[mask].flatten())
@@ -94,8 +157,43 @@ def _train_eval(self, data: Data, stage: Literal['train', 'val', 'test']):
         accuracy = GnnTrainer.compute_accuracy(logits, data.y, mask)
         return GnnTrainerOutput(loss=float(loss), accuracy=accuracy, logits=logits)
 
+    def _train_eval_mini_batch(self, epoch: int, dataloader: NeighborLoader, stage: TrainingStage):
+        if stage == TrainingStage.TRAIN:
+            self.model.train()
+        else:
+            self.model.eval()
+
+        total_loss = total_correct = total_samples = 0
+        num_batches = len(dataloader)
+        batch_logits = []
+        for batch in tqdm(dataloader, total=num_batches, desc=f'{stage.value.capitalize()}ing epoch {epoch}'):
+            batch = batch.to(self.device)
+            y = batch.y[:batch.batch_size].flatten()
+            if stage == TrainingStage.TRAIN:
+                self.optimizer.zero_grad()
+                logits = self.model(batch.x, batch.edge_index)[:batch.batch_size]
+                loss = self.criterion(logits, y)
+                loss.backward()
+                self.optimizer.step()
+            else:
+                with torch.inference_mode():
+                    logits = self.model(batch.x, batch.edge_index)[:batch.batch_size]
+                    loss = self.criterion(logits, y)
+
+            total_loss += float(loss)
+            total_correct += GnnTrainer.compute_accuracy(logits, y)
+            total_samples += y.shape[0]
+            batch_logits.append(logits.detach())
+
+        avg_loss = total_loss / num_batches
+        avg_accuracy = total_correct / total_samples
+        logits = torch.cat(batch_logits)  # full-batch logits
+        return GnnTrainerOutput(loss=avg_loss, accuracy=avg_accuracy, logits=logits)
+
     @staticmethod
-    def compute_accuracy(logits: torch.Tensor, y_true: torch.Tensor, mask: torch.Tensor) -> float:
-        y_pred = logits.argmax(dim=1)
+    def compute_accuracy(logits: torch.Tensor, y_true: torch.Tensor, mask: Optional[torch.Tensor] = None):
+        y_pred = logits.argmax(dim=-1)
+        if mask is None:
+            return int((y_pred == y_true).sum())
         correct = y_pred[mask] == y_true[mask]
         return int(correct.sum()) / int(mask.sum())