Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions tag_llm/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Dict, Optional, Union

import torch
from torch_geometric.data import Data
Expand Down Expand Up @@ -38,7 +38,8 @@ def __init__(
self.lm_encoder = LmEncoder(args=lm_encoder_args)

self._parser: Optional[Parser] = None
self._topk = None
self._split_idx: Optional[Dict[str, torch.Tensor]] = None
self._topk: Optional[int] = None

@property
def dataset(self) -> Data:
Expand All @@ -51,6 +52,13 @@ def dataset(self) -> Data:
def num_classes(self) -> int:
return self._parser.graph.n_classes

@property
def split_idx(self) -> Dict[str, torch.Tensor]:
if self._split_idx is None:
_ = self.dataset
self._split_idx = self._parser.graph.split
return self._split_idx

@property
def topk(self) -> int:
"""TopK ranked LLM predictions."""
Expand Down
3 changes: 2 additions & 1 deletion tag_llm/data/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Dict, List, Optional

import torch
from torch_geometric.data import Data


Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(self, articles_file_path: Path, cache_dir: Path) -> None:
self.class_labels: Optional[List[ClassLabel]] = None
self.articles: Optional[List[Article]] = None
# Split containing train/val/test node ids
self.split: Optional[Dict] = None
self.split: Optional[Dict[str, torch.Tensor]] = None

@abstractmethod
def load(self) -> None:
Expand Down
8 changes: 4 additions & 4 deletions tag_llm/data/parser/pubmed/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ def load_dataset(self) -> None:
self.dataset.y = self._node_labels
self.dataset.edge_index = self._edge_index

# Split dataset nodes into train/val/test and update the train/val/test masks
# Split dataset nodes into train/valid/test and update the train/valid/test masks
n_nodes = self.dataset.num_nodes
node_ids = torch.randperm(n_nodes)
self.split = {}
for split_name in ('train', 'val', 'test'):
for split_name in ('train', 'valid', 'test'):
if split_name == 'train':
subset = slice(0, int(n_nodes * 0.6))
elif split_name == 'val':
elif split_name == 'valid':
subset = slice(int(n_nodes * 0.6), int(n_nodes * 0.8))
else:
subset = slice(int(n_nodes * 0.8), n_nodes)
Expand All @@ -151,4 +151,4 @@ def load_dataset(self) -> None:
mask = torch.zeros(n_nodes, dtype=bool)
mask[ids] = True
setattr(self.dataset, f'{split_name}_mask', mask)
self.split[split_name] = ids.tolist()
self.split[split_name] = ids
136 changes: 117 additions & 19 deletions tag_llm/trainer/gnn_trainer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os
from dataclasses import dataclass
from typing import Literal, Optional
from enum import Enum
from typing import Optional

import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from tqdm import tqdm

from tag_llm.data.dataset import GraphDataset
from tag_llm.gnn_model import NodeClassifier, NodeClassifierArgs
Expand All @@ -14,22 +18,36 @@ class GnnTrainerArgs:
lr: float
weight_decay: float = 0.0
early_stopping_patience: int = 50
batch_size: Optional[int] = None # Mini-batch training
num_neighbors: Optional[int] = None # Mini-batch training
num_workers: Optional[int] = None # Mini-batch training
device: Optional[str] = None

def __post_init__(self) -> None:
self.is_mini_batch_training = self.batch_size is not None
if self.is_mini_batch_training and not self.num_neighbors:
print('`gnn_trainer.num_neighbors` was not provided. Using the default value of 10.')
self.num_neighbors = 10
self.num_workers = self.num_workers or os.cpu_count()

@dataclass
class GnnTrainerOutput:
loss: float
accuracy: float
logits: torch.Tensor

class TrainingStage(str, Enum):
TRAIN = 'train'
VALID = 'valid'
TEST = 'test'

class GnnTrainer:

class GnnTrainer:
def __init__(self, trainer_args: GnnTrainerArgs, graph_dataset: GraphDataset, model_args: NodeClassifierArgs) -> None:

self.trainer_args = trainer_args
self.model_args = model_args
self.dataset: Data = graph_dataset.dataset
self.split_idx = graph_dataset.split_idx
self.device = trainer_args.device or ('cuda' if torch.cuda.is_available() else 'cpu')

use_predictions = graph_dataset.feature_type == 'prediction'
Expand All @@ -50,17 +68,24 @@ def train(self) -> GnnTrainerOutput:
patience = self.trainer_args.early_stopping_patience
best_val_loss = float('inf')
epochs_without_improvement = 0
dataloaders = None

for epoch in range(1, self.trainer_args.epochs + 1):
train_output = self._train_eval(self.dataset, stage='train')
val_output = self._train_eval(self.dataset, stage='val')
if self.trainer_args.is_mini_batch_training:
if dataloaders is None:
dataloaders = self._get_dataloaders()
train_output = self._train_eval_mini_batch(epoch, dataloaders['train'], TrainingStage.TRAIN)
valid_output = self._train_eval_mini_batch(epoch, dataloaders['valid'], TrainingStage.VALID)
else:
train_output = self._train_eval_full_batch(TrainingStage.TRAIN)
valid_output = self._train_eval_full_batch(TrainingStage.VALID)
print(
f'Epoch: {epoch:03d} | Train loss: {train_output.loss:.4f}, '
f'Val loss: {val_output.loss:.4f}, Train accuracy: {train_output.accuracy:.4f}, '
f'Val accuracy: {val_output.accuracy:.4f}'
f'Valid loss: {valid_output.loss:.4f}, Train accuracy: {train_output.accuracy:.4f}, '
f'Valid accuracy: {valid_output.accuracy:.4f}'
)
if val_output.loss < best_val_loss:
best_val_loss = val_output.loss
if valid_output.loss < best_val_loss:
best_val_loss = valid_output.loss
epochs_without_improvement = 0
else:
epochs_without_improvement += 1
Expand All @@ -69,18 +94,56 @@ def train(self) -> GnnTrainerOutput:
print(f'Early stopping on epoch {epoch} due to no improvement in validation loss for {patience} epochs.')
break

output = self._train_eval(self.dataset, stage='test')
return output

def _train_eval(self, data: Data, stage: Literal['train', 'val', 'test']):
if stage == 'train':
if self.trainer_args.is_mini_batch_training:
test_output = self._train_eval_mini_batch(epoch, dataloaders['test'], TrainingStage.TEST)
else:
test_output = self._train_eval_full_batch(TrainingStage.TEST)
return test_output

def _get_dataloaders(self):
config = self.trainer_args
num_neighbors = [config.num_neighbors] * self.model_args.num_layers
persistent_workers = config.num_workers > 0
train_dataloader = NeighborLoader(
data=self.dataset,
num_neighbors=num_neighbors,
input_nodes=self.split_idx['train'],
batch_size=config.batch_size,
shuffle=True,
num_workers=config.num_workers,
persistent_workers=persistent_workers,
)
valid_dataloader = NeighborLoader(
data=self.dataset,
num_neighbors=num_neighbors,
input_nodes=self.split_idx['valid'],
batch_size=config.batch_size,
num_workers=config.num_workers,
persistent_workers=persistent_workers,
)
test_dataloader = NeighborLoader(
data=self.dataset,
num_neighbors=num_neighbors,
input_nodes=self.split_idx['test'],
batch_size=config.batch_size,
num_workers=config.num_workers,
persistent_workers=persistent_workers,
)
return dict(train=train_dataloader, valid=valid_dataloader, test=test_dataloader)

def _train_eval_full_batch(self, stage: TrainingStage):
if stage == TrainingStage.TRAIN:
self.model.train()
else:
self.model.eval()

data = data.to(self.device)
mask = getattr(data, f'{stage}_mask')
if stage == 'train':
data = self.dataset.to(self.device)
mask = getattr(data, f'{stage.value}_mask', None)
assert mask, (
'Missing `*_mask` attributes from the dataset! `train_mask`, '
'`valid_mask` and `test_mask` are required for full-batch training.'
)
if stage == TrainingStage.TRAIN:
self.optimizer.zero_grad()
logits = self.model(data.x, data.edge_index)
loss = self.criterion(logits[mask], data.y[mask].flatten())
Expand All @@ -94,8 +157,43 @@ def _train_eval(self, data: Data, stage: Literal['train', 'val', 'test']):
accuracy = GnnTrainer.compute_accuracy(logits, data.y, mask)
return GnnTrainerOutput(loss=float(loss), accuracy=accuracy, logits=logits)

def _train_eval_mini_batch(self, epoch: int, dataloader: NeighborLoader, stage: TrainingStage):
if stage == TrainingStage.TRAIN:
self.model.train()
else:
self.model.eval()

total_loss = total_correct = total_samples = 0
num_batches = len(dataloader)
batch_logits = []
for batch in tqdm(dataloader, total=num_batches, desc=f'{stage.value.capitalize()}ing epoch {epoch}'):
batch = batch.to(self.device)
y = batch.y[:batch.batch_size].flatten()
if stage == TrainingStage.TRAIN:
self.optimizer.zero_grad()
logits = self.model(batch.x, batch.edge_index)[:batch.batch_size]
loss = self.criterion(logits, y)
loss.backward()
self.optimizer.step()
else:
with torch.inference_mode():
logits = self.model(batch.x, batch.edge_index)[:batch.batch_size]
loss = self.criterion(logits, y)

total_loss += float(loss)
total_correct += GnnTrainer.compute_accuracy(logits, y)
total_samples += y.shape[0]
batch_logits.append(logits)

avg_loss = total_loss / num_batches
avg_accuracy = total_correct / total_samples
logits = torch.cat(batch_logits) # full-batch logits
return GnnTrainerOutput(loss=avg_loss, accuracy=avg_accuracy, logits=logits)

@staticmethod
def compute_accuracy(logits: torch.Tensor, y_true: torch.Tensor, mask: torch.Tensor) -> float:
y_pred = logits.argmax(dim=1)
def compute_accuracy(logits: torch.Tensor, y_true: torch.Tensor, mask: Optional[torch.Tensor] = None):
y_pred = logits.argmax(dim=-1)
if mask is None:
return int((y_pred == y_true).sum())
correct = y_pred[mask] == y_true[mask]
return int(correct.sum()) / int(mask.sum())