From 0cc964c61ab8a389684fe1383e6239096f8a1792 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Mon, 6 Oct 2025 11:08:13 +0200 Subject: [PATCH 1/3] WIP --- src/sequifier/train.py | 83 ++++++++++++++++++++++++-------------- tools/check_equivalence.py | 2 +- 2 files changed, 53 insertions(+), 32 deletions(-) diff --git a/src/sequifier/train.py b/src/sequifier/train.py index 493a8d22..ad7cd316 100644 --- a/src/sequifier/train.py +++ b/src/sequifier/train.py @@ -13,11 +13,12 @@ import torch._dynamo from beartype import beartype from torch import Tensor, nn +from torch.amp import GradScaler from torch.nn import ModuleDict, TransformerEncoder, TransformerEncoderLayer from torch.nn.functional import one_hot import torch.nn.functional as F import torch.nn.init as init -from torch.utils.checkpoint import checkpoint +from torch.utils.checkpoint import checkpoint torch._dynamo.config.suppress_errors = True @@ -370,8 +371,9 @@ def forward(self, src: Tensor) -> Tensor: residual = output x = self.ln1(output) - # A single call to the efficient parallel layer - attn_out = layer(x) + # A single call to the efficient parallel layer, with checkpointing + # Trades computation for memory by not storing intermediate activations + attn_out = checkpoint(layer, x, use_reentrant=False) attn_out = self.drop(attn_out) @@ -454,6 +456,24 @@ def __init__(self, hparams: Any): self.pos_encoder[col] = nn.Embedding( self.seq_length, self.d_model_by_column[col] ) + + + self.decoder = ModuleDict() + self.softmax = ModuleDict() + for target_column, target_column_type in self.target_column_types.items(): + if target_column_type == "categorical": + self.decoder[target_column] = nn.Linear( + self.embedding_size, + self.n_classes[target_column], + ) + self.softmax[target_column] = nn.LogSoftmax(dim=-1) + elif target_column_type == "real": + self.decoder[target_column] = nn.Linear(self.embedding_size, 1) + else: + raise ValueError( + f"Target column type {target_column_type} not in ['categorical', 'real']" + ) + if self.use_cross_attention: self.R = len(self.real_columns) self.d_col = self.d_model_by_column[self.real_columns[0]] @@ -498,22 +518,6 @@ def __init__(self, hparams: Any): ) - self.decoder = ModuleDict() - self.softmax = ModuleDict() - for target_column, target_column_type in self.target_column_types.items(): - if target_column_type == "categorical": - self.decoder[target_column] = nn.Linear( - self.embedding_size, - self.n_classes[target_column], - ) - self.softmax[target_column] = nn.LogSoftmax(dim=-1) - elif target_column_type == "real": - self.decoder[target_column] = nn.Linear(self.embedding_size, 1) - else: - raise ValueError( - f"Target column type {target_column_type} not in ['categorical', 'real']" - ) - self.device = hparams.training_spec.device self.device_max_concat_length = hparams.training_spec.device_max_concat_length self.criterion = self._init_criterion(hparams=hparams) @@ -540,6 +544,10 @@ def __init__(self, hparams: Any): **self._filter_key(hparams.training_spec.scheduler, "name") ) + # Initialize the gradient scaler for Automatic Mixed Precision (AMP) + # It's enabled only when training on a CUDA device. 
+ self.scaler = GradScaler(enabled=(self.device == 'cuda')) + self.iter_save = hparams.training_spec.iter_save self.continue_training = hparams.training_spec.continue_training load_string = self._load_weights_conditional() #checkpoint loading @@ -910,11 +918,14 @@ def _train_epoch( ) output = self.forward_train(data) # [B, T, 1] for each target column - already decoded - # loss should already reflect your intended reduction/masking - loss, losses = self._calculate_loss(output, targets) + # AMP: Enter autocast context for forward pass and loss calculation. + # This enables automatic casting to float16 for performance. + with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=(self.device == 'cuda')): + output = self.forward_train(data) + loss, losses = self._calculate_loss(output, targets) - # keep training dynamics IDENTICAL: do NOT rescale by accumulation_steps here - loss.backward() + # AMP: Scale the loss before backward pass to prevent gradient underflow. + self.scaler.scale(loss).backward() torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0) if ( @@ -922,7 +933,10 @@ def _train_epoch( or (batch_count + 1) % self.accumulation_steps == 0 or (batch_count + 1) == num_batches ): - self.optimizer.step() + # AMP: Unscales gradients and steps the optimizer. + self.scaler.step(self.optimizer) + # AMP: Updates the scale for the next iteration. + self.scaler.update() self.optimizer.zero_grad() # --- bookkeeping --- @@ -942,8 +956,9 @@ def _train_epoch( ) log_loss_sum = 0.0 start_time = time.time() + + del data, targets, output, loss, losses - # --- return proper epoch-average training loss (NEW) --- epoch_avg_loss = epoch_loss_sum / max(1, epoch_loss_count) return np.float32(epoch_avg_loss) @@ -1006,6 +1021,7 @@ def _evaluate( self.eval() # turn on evaluation mode with torch.no_grad(): + num_batches = math.ceil( X_valid[self.target_columns[0]].shape[0] / self.batch_size ) # any column will do @@ -1018,12 +1034,13 @@ def _evaluate( batch_start + self.batch_size, to_device=True, ) - output = self.forward_train(data) - total_loss_iter, total_losses_iter = self._calculate_loss( - output, targets - ) - total_loss_collect.append(total_loss_iter.cpu()) - total_losses_collect.append(total_losses_iter) + with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=(self.device == 'cuda')): + output = self.forward_train(data) + total_loss_iter, total_losses_iter = self._calculate_loss( + output, targets + ) + total_loss_collect.append(total_loss_iter.cpu()) + total_losses_collect.append(total_losses_iter) torch.cuda.empty_cache() @@ -1172,6 +1189,7 @@ def _save(self, epoch: int, val_loss: np.float32, train_loss: np.float32) -> Non "epoch": epoch, "model_state_dict": self.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(), + "scaler_state_dict": self.scaler.state_dict(), "loss": val_loss, }, output_path, @@ -1231,6 +1249,9 @@ def _load_weights_conditional(self) -> str: self.load_state_dict(checkpoint["model_state_dict"]) self.start_epoch = checkpoint["epoch"] + 1 self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + # Load scaler state if it exists in the checkpoint, for resuming AMP training + if "scaler_state_dict" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) return f"Loading model weights from {latest_model_path}. 
Total params: {format_number(pytorch_total_params)}" else: self.start_epoch = 1 diff --git a/tools/check_equivalence.py b/tools/check_equivalence.py index 635b61df..7e7cb620 100644 --- a/tools/check_equivalence.py +++ b/tools/check_equivalence.py @@ -27,7 +27,7 @@ def check_equivalence(): 4. Asserting that their outputs are numerically very close. """ # 1. Define hyperparameters for the test models - B, T, R, D_EMBED, N_LAYERS, DROP = 100, 120, 16, 64, 2, 0.1 + B, T, R, D_EMBED, N_LAYERS, DROP = 10, 30, 16, 64, 4, 0.1 NUM_HEADS = R D_MODEL = R * D_EMBED D_HEAD = D_EMBED # In this architecture, d_head equals d_embed From 7ecb5b079925a27246d476a43cbb4b8562d86303 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Mon, 6 Oct 2025 11:19:19 +0200 Subject: [PATCH 2/3] set floor --- src/sequifier/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sequifier/train.py b/src/sequifier/train.py index ad7cd316..0f339cc1 100644 --- a/src/sequifier/train.py +++ b/src/sequifier/train.py @@ -1022,16 +1022,16 @@ def _evaluate( with torch.no_grad(): - num_batches = math.ceil( + num_batches = max(1, math.floor( X_valid[self.target_columns[0]].shape[0] / self.batch_size - ) # any column will do + )) # any column will do total_loss_collect, total_losses_collect = [], [] for batch_start in range(0, num_batches * self.batch_size, self.batch_size): data, targets = self._get_batch( X_valid, y_valid, batch_start, - batch_start + self.batch_size, + self.batch_size, to_device=True, ) with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=(self.device == 'cuda')): From 5a923ee659a51af40c7b3767882d203ae16ece09 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Wed, 8 Oct 2025 19:42:56 +0200 Subject: [PATCH 3/3] add per_region ffn --- src/sequifier/train.py | 114 +++++++++++++++++++++++++++++++++-------- 1 file changed, 93 insertions(+), 21 deletions(-) diff --git a/src/sequifier/train.py b/src/sequifier/train.py index 0f339cc1..22f957d6 100644 --- a/src/sequifier/train.py +++ b/src/sequifier/train.py @@ -131,11 +131,54 @@ def format_number(number: Union[int, float, np.float32]) -> str: return f"{number_adjusted:5.2f}e{order_of_magnitude}" + +class BlockDiagonalLinear(nn.Module): + """ + A block-diagonal linear layer implemented with a grouped 1D convolution. + + This layer is equivalent to applying a separate Dense layer to each + block of features in the input. + """ + def __init__(self, num_blocks: int, in_features_per_block: int, out_features_per_block: int): + super().__init__() + self.num_blocks = num_blocks + + # Calculate total input and output features + total_in_features = num_blocks * in_features_per_block + total_out_features = num_blocks * out_features_per_block + + # The core of the implementation is a Conv1d layer with groups. + # - kernel_size=1 makes it a fully-connected operation on the channel dimension. + # - groups=num_blocks ensures that each block is processed independently. 
+ self.conv = nn.Conv1d( + in_channels=total_in_features, + out_channels=total_out_features, + kernel_size=1, + groups=num_blocks + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor of shape (batch_size, num_blocks * in_features_per_block) + """ + # Add a dummy spatial dimension for Conv1d: (N, C) -> (N, C, 1) + x = x.unsqueeze(2) + + # Apply the grouped convolution + output = self.conv(x) + + # Remove the dummy spatial dimension: (N, C_out, 1) -> (N, C_out) + output = output.squeeze(2) + + return output + class GroupedFFN(nn.Module): """ - Two FFN modes: - - shared_per_region: one FFN shared across regions, applied per region with LN over Dh - - shared_global : one FFN over concatenated regions with LN over R*Dh + Three FFN modes: + - per_region : a separate FFN and LayerNorm for each region. + - shared_per_region : one FFN and LayerNorm shared across all regions. + - shared_global : one FFN and LayerNorm over the concatenated regions. """ def __init__( self, @@ -143,22 +186,25 @@ def __init__( d_head: int, mult: int = 4, dropout: float = 0.1, - mode: str = "shared_per_region", # or "shared_global" + mode: str = "", ): super().__init__() - assert mode in {"shared_per_region", "shared_global"} + assert mode in {"per_region", "shared_per_region", "shared_global"} self.R = num_regions self.Dh = d_head self.mode = mode self.drop = nn.Dropout(dropout) - if mode == "shared_per_region": - # LN over Dh, then FFN(Dh -> mult*Dh -> Dh), SAME weights for all regions + if mode == "per_region": + # 💡 Apply LayerNorm per region (on d_head) + self.ln = nn.LayerNorm(d_head) + self.ff1 = BlockDiagonalLinear(num_regions, d_head, mult * d_head) + self.ff2 = BlockDiagonalLinear(num_regions, mult * d_head, d_head) + elif mode == "shared_per_region": self.ln = nn.LayerNorm(d_head) self.ff1 = nn.Linear(d_head, mult * d_head) self.ff2 = nn.Linear(mult * d_head, d_head) - else: - # LN over R*Dh, then FFN(R*Dh -> mult*R*Dh -> R*Dh) + elif mode == "shared_global": d_model = num_regions * d_head self.ln = nn.LayerNorm(d_model) self.ff1 = nn.Linear(d_model, mult * d_model) @@ -172,26 +218,52 @@ def forward(self, y: torch.Tensor) -> torch.Tensor: B, T, D = y.shape assert D == self.R * self.Dh, f"got D={D}, expected R*Dh={self.R*self.Dh}" - if self.mode == "shared_per_region": - # reshape to apply LN/FFN per region (same weights for all regions) - y_per = y.view(B, T, self.R, self.Dh).contiguous() # [B, T, R, Dh] - y_per = self.ln(y_per) # LN over Dh - z = self.ff1(y_per) # (..., Dh)->(..., mult*Dh) + if self.mode == "per_region": + # Reshape to expose regions for LayerNorm + # [B, T, R*Dh] -> [B, T, R, Dh] + y_per = y.view(B, T, self.R, self.Dh) + + # Apply LayerNorm independently to each region's Dh features + y_norm = self.ln(y_per) + + # Reshape for BlockDiagonalLinear which expects a flattened batch/time dim + # [B, T, R, Dh] -> [B*T, R*Dh] + y_flat = y_norm.view(B * T, D) + + z = self.ff1(y_flat) z = F.gelu(z) z = self.drop(z) - out = self.ff2(z) # (..., mult*Dh)->(..., Dh) + out_flat = self.ff2(z) + out_flat = self.drop(out_flat) + + # Reshape back to the original tensor shape + out = out_flat.view(B, T, D) + return out + + elif self.mode == "shared_global": + y_glob = self.ln(y) + z = self.ff1(y_glob) + z = F.gelu(z) + z = self.drop(z) + out = self.ff2(z) out = self.drop(out) - out = out.view(B, T, self.R * self.Dh).contiguous() # back to [B, T, R*Dh] return out - else: # "shared_global" - # treat all regions as one big vector - y_glob = self.ln(y) # LN over 
R*Dh - z = self.ff1(y_glob) # [B, T, R*Dh] -> [B, T, mult*R*Dh] + elif self.mode == "shared_per_region": + # Reshape to apply shared layers per region + # [B, T, R*Dh] -> [B, T, R, Dh] + y_per = y.view(B, T, self.R, self.Dh) + + y_norm = self.ln(y_per) + z = self.ff1(y_norm) z = F.gelu(z) z = self.drop(z) - out = self.ff2(z) # -> [B, T, R*Dh] + out = self.ff2(z) out = self.drop(out) + + # Reshape back to the original tensor shape + # [B, T, R, Dh] -> [B, T, R*Dh] + out = out.view(B, T, self.R * self.Dh) return out class RegionAttention(nn.Module):
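
A minimal sketch of how the BlockDiagonalLinear layer introduced in patch 3/3 can be checked against independent per-block nn.Linear layers, in the spirit of tools/check_equivalence.py. The import path sequifier.train and all sizes below are assumptions for illustration, not part of the patches.

# Equivalence sketch for BlockDiagonalLinear (assumed importable as below).
import torch
from torch import nn

from sequifier.train import BlockDiagonalLinear  # assumed module path

R, D_IN, D_OUT, B = 4, 8, 16, 5  # illustrative sizes only
block = BlockDiagonalLinear(
    num_blocks=R, in_features_per_block=D_IN, out_features_per_block=D_OUT
)

# The grouped Conv1d weight has shape [R * D_OUT, D_IN, 1]; block r owns
# output rows r*D_OUT:(r+1)*D_OUT and reads input slice r*D_IN:(r+1)*D_IN.
reference = []
for r in range(R):
    lin = nn.Linear(D_IN, D_OUT)
    with torch.no_grad():
        lin.weight.copy_(block.conv.weight[r * D_OUT:(r + 1) * D_OUT, :, 0])
        lin.bias.copy_(block.conv.bias[r * D_OUT:(r + 1) * D_OUT])
    reference.append(lin)

x = torch.randn(B, R * D_IN)
out_block = block(x)  # [B, R*D_OUT] via the grouped convolution
out_ref = torch.cat(
    [reference[r](x[:, r * D_IN:(r + 1) * D_IN]) for r in range(R)], dim=1
)
assert torch.allclose(out_block, out_ref, atol=1e-5)

The same per-block weight slicing would apply if porting an existing stack of per-region nn.Linear layers into the grouped convolution.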