203 changes: 148 additions & 55 deletions src/sequifier/train.py
@@ -13,11 +13,12 @@
import torch._dynamo
from beartype import beartype
from torch import Tensor, nn
from torch.amp import GradScaler
from torch.nn import ModuleDict, TransformerEncoder, TransformerEncoderLayer
from torch.nn.functional import one_hot
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.checkpoint import checkpoint


torch._dynamo.config.suppress_errors = True
@@ -130,34 +131,80 @@ def format_number(number: Union[int, float, np.float32]) -> str:
return f"{number_adjusted:5.2f}e{order_of_magnitude}"



class BlockDiagonalLinear(nn.Module):
"""
A block-diagonal linear layer implemented with a grouped 1D convolution.

This layer is equivalent to applying a separate linear (fully connected) layer
to each block of features in the input.
"""
def __init__(self, num_blocks: int, in_features_per_block: int, out_features_per_block: int):
super().__init__()
self.num_blocks = num_blocks

# Calculate total input and output features
total_in_features = num_blocks * in_features_per_block
total_out_features = num_blocks * out_features_per_block

# The core of the implementation is a Conv1d layer with groups.
# - kernel_size=1 makes it a fully-connected operation on the channel dimension.
# - groups=num_blocks ensures that each block is processed independently.
self.conv = nn.Conv1d(
in_channels=total_in_features,
out_channels=total_out_features,
kernel_size=1,
groups=num_blocks
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (torch.Tensor): Input tensor of shape (batch_size, num_blocks * in_features_per_block)
"""
# Add a dummy spatial dimension for Conv1d: (N, C) -> (N, C, 1)
x = x.unsqueeze(2)

# Apply the grouped convolution
output = self.conv(x)

# Remove the dummy spatial dimension: (N, C_out, 1) -> (N, C_out)
output = output.squeeze(2)

return output
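
As a quick sanity check, here is a minimal sketch (not part of the change; block sizes are arbitrary) showing that the grouped Conv1d matches an independent per-block linear layer once the weights are copied over:

import torch
import torch.nn as nn

num_blocks, d_in, d_out = 4, 8, 16
layer = BlockDiagonalLinear(num_blocks, d_in, d_out)

# conv.weight has shape [num_blocks * d_out, d_in, 1]; block b owns rows b*d_out:(b+1)*d_out
refs = []
for b in range(num_blocks):
    lin = nn.Linear(d_in, d_out)
    lin.weight.data = layer.conv.weight.data[b * d_out:(b + 1) * d_out, :, 0].clone()
    lin.bias.data = layer.conv.bias.data[b * d_out:(b + 1) * d_out].clone()
    refs.append(lin)

x = torch.randn(2, num_blocks * d_in)
expected = torch.cat([refs[b](x[:, b * d_in:(b + 1) * d_in]) for b in range(num_blocks)], dim=-1)
assert torch.allclose(layer(x), expected, atol=1e-6)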

class GroupedFFN(nn.Module):
"""
Two FFN modes:
- shared_per_region: one FFN shared across regions, applied per region with LN over Dh
- shared_global : one FFN over concatenated regions with LN over R*Dh
Three FFN modes:
- per_region : a separate FFN and LayerNorm for each region.
- shared_per_region : one FFN and LayerNorm shared across all regions.
- shared_global : one FFN and LayerNorm over the concatenated regions.
"""
def __init__(
self,
num_regions: int,
d_head: int,
mult: int = 4,
dropout: float = 0.1,
mode: str = "shared_per_region", # or "shared_global"
mode: str = "",
):
super().__init__()
assert mode in {"shared_per_region", "shared_global"}
assert mode in {"per_region", "shared_per_region", "shared_global"}
self.R = num_regions
self.Dh = d_head
self.mode = mode
self.drop = nn.Dropout(dropout)

if mode == "shared_per_region":
# LN over Dh, then FFN(Dh -> mult*Dh -> Dh), SAME weights for all regions
if mode == "per_region":
# 💡 Apply LayerNorm per region (on d_head)
self.ln = nn.LayerNorm(d_head)
self.ff1 = BlockDiagonalLinear(num_regions, d_head, mult * d_head)
self.ff2 = BlockDiagonalLinear(num_regions, mult * d_head, d_head)
elif mode == "shared_per_region":
self.ln = nn.LayerNorm(d_head)
self.ff1 = nn.Linear(d_head, mult * d_head)
self.ff2 = nn.Linear(mult * d_head, d_head)
else:
# LN over R*Dh, then FFN(R*Dh -> mult*R*Dh -> R*Dh)
elif mode == "shared_global":
d_model = num_regions * d_head
self.ln = nn.LayerNorm(d_model)
self.ff1 = nn.Linear(d_model, mult * d_model)
@@ -171,26 +218,52 @@ def forward(self, y: torch.Tensor) -> torch.Tensor:
B, T, D = y.shape
assert D == self.R * self.Dh, f"got D={D}, expected R*Dh={self.R*self.Dh}"

if self.mode == "shared_per_region":
# reshape to apply LN/FFN per region (same weights for all regions)
y_per = y.view(B, T, self.R, self.Dh).contiguous() # [B, T, R, Dh]
y_per = self.ln(y_per) # LN over Dh
z = self.ff1(y_per) # (..., Dh)->(..., mult*Dh)
if self.mode == "per_region":
# Reshape to expose regions for LayerNorm
# [B, T, R*Dh] -> [B, T, R, Dh]
y_per = y.view(B, T, self.R, self.Dh)

# Apply LayerNorm independently to each region's Dh features
y_norm = self.ln(y_per)

# Reshape for BlockDiagonalLinear which expects a flattened batch/time dim
# [B, T, R, Dh] -> [B*T, R*Dh]
y_flat = y_norm.view(B * T, D)

z = self.ff1(y_flat)
z = F.gelu(z)
z = self.drop(z)
out = self.ff2(z) # (..., mult*Dh)->(..., Dh)
out_flat = self.ff2(z)
out_flat = self.drop(out_flat)

# Reshape back to the original tensor shape
out = out_flat.view(B, T, D)
return out

elif self.mode == "shared_global":
y_glob = self.ln(y)
z = self.ff1(y_glob)
z = F.gelu(z)
z = self.drop(z)
out = self.ff2(z)
out = self.drop(out)
out = out.view(B, T, self.R * self.Dh).contiguous() # back to [B, T, R*Dh]
return out

else: # "shared_global"
# treat all regions as one big vector
y_glob = self.ln(y) # LN over R*Dh
z = self.ff1(y_glob) # [B, T, R*Dh] -> [B, T, mult*R*Dh]
elif self.mode == "shared_per_region":
# Reshape to apply shared layers per region
# [B, T, R*Dh] -> [B, T, R, Dh]
y_per = y.view(B, T, self.R, self.Dh)

y_norm = self.ln(y_per)
z = self.ff1(y_norm)
z = F.gelu(z)
z = self.drop(z)
out = self.ff2(z) # -> [B, T, R*Dh]
out = self.ff2(z)
out = self.drop(out)

# Reshape back to the original tensor shape
# [B, T, R, Dh] -> [B, T, R*Dh]
out = out.view(B, T, self.R * self.Dh)
return out
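
A minimal usage sketch (shapes chosen arbitrarily): all three modes consume and return a [B, T, R*Dh] tensor, so they can be swapped behind a single config flag:

import torch

B, T, R, Dh = 2, 5, 4, 16
y = torch.randn(B, T, R * Dh)
for mode in ("per_region", "shared_per_region", "shared_global"):
    ffn = GroupedFFN(num_regions=R, d_head=Dh, mult=4, dropout=0.0, mode=mode)
    assert ffn(y).shape == (B, T, R * Dh)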

class RegionAttention(nn.Module):
@@ -370,8 +443,9 @@ def forward(self, src: Tensor) -> Tensor:
residual = output
x = self.ln1(output)

# A single call to the efficient parallel layer
attn_out = layer(x)
# A single call to the efficient parallel layer, with checkpointing
# Trades computation for memory by not storing intermediate activations
attn_out = checkpoint(layer, x, use_reentrant=False)

attn_out = self.drop(attn_out)
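
For reference, a minimal sketch of the non-reentrant checkpointing pattern used above (the module and tensor are stand-ins): activations inside the wrapped callable are not stored during the forward pass and are recomputed during backward.

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(64, 256), torch.nn.GELU(), torch.nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)  # block is re-run during backward
y.sum().backward()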

Expand Down Expand Up @@ -454,6 +528,24 @@ def __init__(self, hparams: Any):
self.pos_encoder[col] = nn.Embedding(
self.seq_length, self.d_model_by_column[col]
)


self.decoder = ModuleDict()
self.softmax = ModuleDict()
for target_column, target_column_type in self.target_column_types.items():
if target_column_type == "categorical":
self.decoder[target_column] = nn.Linear(
self.embedding_size,
self.n_classes[target_column],
)
self.softmax[target_column] = nn.LogSoftmax(dim=-1)
elif target_column_type == "real":
self.decoder[target_column] = nn.Linear(self.embedding_size, 1)
else:
raise ValueError(
f"Target column type {target_column_type} not in ['categorical', 'real']"
)

if self.use_cross_attention:
self.R = len(self.real_columns)
self.d_col = self.d_model_by_column[self.real_columns[0]]
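
A minimal sketch of how the per-target decoder heads above behave (the column names and sizes are made up for illustration): categorical targets produce class log-probabilities, real-valued targets a single scalar per position.

import torch
from torch import nn

embedding_size = 64
decoder = nn.ModuleDict({
    "item_id": nn.Linear(embedding_size, 1000),  # hypothetical categorical target, 1000 classes
    "price": nn.Linear(embedding_size, 1),       # hypothetical real-valued target
})
log_softmax = nn.LogSoftmax(dim=-1)

h = torch.randn(4, 20, embedding_size)           # [B, T, embedding_size]
log_probs = log_softmax(decoder["item_id"](h))   # [4, 20, 1000], suitable for NLLLoss
price_pred = decoder["price"](h)                 # [4, 20, 1]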
@@ -498,22 +590,6 @@ def __init__(self, hparams: Any):
)


self.decoder = ModuleDict()
self.softmax = ModuleDict()
for target_column, target_column_type in self.target_column_types.items():
if target_column_type == "categorical":
self.decoder[target_column] = nn.Linear(
self.embedding_size,
self.n_classes[target_column],
)
self.softmax[target_column] = nn.LogSoftmax(dim=-1)
elif target_column_type == "real":
self.decoder[target_column] = nn.Linear(self.embedding_size, 1)
else:
raise ValueError(
f"Target column type {target_column_type} not in ['categorical', 'real']"
)

self.device = hparams.training_spec.device
self.device_max_concat_length = hparams.training_spec.device_max_concat_length
self.criterion = self._init_criterion(hparams=hparams)
@@ -540,6 +616,10 @@ def __init__(self, hparams: Any):
**self._filter_key(hparams.training_spec.scheduler, "name")
)

# Initialize the gradient scaler for Automatic Mixed Precision (AMP)
# It's enabled only when training on a CUDA device.
self.scaler = GradScaler(enabled=(self.device == 'cuda'))

self.iter_save = hparams.training_spec.iter_save
self.continue_training = hparams.training_spec.continue_training
load_string = self._load_weights_conditional() #checkpoint loading
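
A minimal sketch of the scaler setup (device handling simplified relative to the hparams-driven code): with enabled=False the scaler is an inert pass-through, so CPU runs are unaffected.

import torch
from torch.amp import GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
scaler = GradScaler(enabled=(device == "cuda"))  # note: a string like "cuda:0" would not match this exact comparison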
@@ -910,19 +990,25 @@ def _train_epoch(
)
output = self.forward_train(data) # [B, T, 1] for each target column - already decoded

# loss should already reflect your intended reduction/masking
loss, losses = self._calculate_loss(output, targets)
# AMP: Enter autocast context for forward pass and loss calculation.
# This enables automatic casting to float16 for performance.
with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=(self.device == 'cuda')):
output = self.forward_train(data)
loss, losses = self._calculate_loss(output, targets)

# keep training dynamics IDENTICAL: do NOT rescale by accumulation_steps here
loss.backward()
# AMP: Scale the loss before backward pass to prevent gradient underflow.
self.scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)

if (
self.accumulation_steps is None
or (batch_count + 1) % self.accumulation_steps == 0
or (batch_count + 1) == num_batches
):
self.optimizer.step()
# AMP: Unscales gradients and steps the optimizer.
self.scaler.step(self.optimizer)
# AMP: Updates the scale for the next iteration.
self.scaler.update()
self.optimizer.zero_grad()

# --- bookkeeping ---
Expand All @@ -942,8 +1028,9 @@ def _train_epoch(
)
log_loss_sum = 0.0
start_time = time.time()

del data, targets, output, loss, losses

# --- return proper epoch-average training loss (NEW) ---
epoch_avg_loss = epoch_loss_sum / max(1, epoch_loss_count)
return np.float32(epoch_avg_loss)
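
For orientation, the canonical autocast/GradScaler loop with gradient accumulation, as a minimal sketch (the model, data, and clipping threshold are placeholders, not the project's exact code). When gradients are clipped, scaler.unscale_ is normally called first so the norm is computed on unscaled gradients.

import torch
from torch.amp import GradScaler, autocast

model = torch.nn.Linear(32, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters())
scaler = GradScaler(enabled=True)
accumulation_steps = 4
loader = [(torch.randn(16, 32).cuda(), torch.randn(16, 8).cuda()) for _ in range(8)]

for step, (x, y) in enumerate(loader):
    with autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()                 # scale to avoid fp16 gradient underflow
    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)                # unscale before clipping so the norm is meaningful
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)                    # skipped internally if inf/nan gradients were found
        scaler.update()
        optimizer.zero_grad()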

@@ -1006,24 +1093,26 @@ def _evaluate(
self.eval() # turn on evaluation mode

with torch.no_grad():
num_batches = math.ceil(

num_batches = max(1, math.floor(
X_valid[self.target_columns[0]].shape[0] / self.batch_size
) # any column will do
)) # any column will do
total_loss_collect, total_losses_collect = [], []
for batch_start in range(0, num_batches * self.batch_size, self.batch_size):
data, targets = self._get_batch(
X_valid,
y_valid,
batch_start,
batch_start + self.batch_size,
self.batch_size,
to_device=True,
)
output = self.forward_train(data)
total_loss_iter, total_losses_iter = self._calculate_loss(
output, targets
)
total_loss_collect.append(total_loss_iter.cpu())
total_losses_collect.append(total_losses_iter)
with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=(self.device == 'cuda')):
output = self.forward_train(data)
total_loss_iter, total_losses_iter = self._calculate_loss(
output, targets
)
total_loss_collect.append(total_loss_iter.cpu())
total_losses_collect.append(total_losses_iter)

torch.cuda.empty_cache()
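
To illustrate the batch-count change with made-up numbers: for 1,000 validation rows and batch_size=128, math.ceil gave 8 batches with a final partial batch of 104 rows, while max(1, math.floor(...)) gives 7 full batches and skips the 104-row remainder.

import math
n_rows, batch_size = 1000, 128
math.ceil(n_rows / batch_size)            # 8 -> last batch holds only 104 rows
max(1, math.floor(n_rows / batch_size))   # 7 -> full batches only; 104 rows are not evaluated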

@@ -1172,6 +1261,7 @@ def _save(self, epoch: int, val_loss: np.float32, train_loss: np.float32) -> None:
"epoch": epoch,
"model_state_dict": self.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scaler_state_dict": self.scaler.state_dict(),
"loss": val_loss,
},
output_path,
@@ -1231,6 +1321,9 @@ def _load_weights_conditional(self) -> str:
self.load_state_dict(checkpoint["model_state_dict"])
self.start_epoch = checkpoint["epoch"] + 1
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# Load scaler state if it exists in the checkpoint, for resuming AMP training
if "scaler_state_dict" in checkpoint:
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
return f"Loading model weights from {latest_model_path}. Total params: {format_number(pytorch_total_params)}"
else:
self.start_epoch = 1
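
A minimal round-trip sketch of the extended checkpoint format (model, optimizer, and path are placeholders): checkpoints written before this change simply lack the scaler entry, which is why the key is checked before restoring.

import torch
from torch.amp import GradScaler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters())
scaler = GradScaler(enabled=False)               # disabled so the sketch also runs on CPU

torch.save(
    {
        "epoch": 3,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "loss": 0.123,
    },
    "checkpoint.pt",
)

checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if "scaler_state_dict" in checkpoint:            # absent in pre-AMP checkpoints
    scaler.load_state_dict(checkpoint["scaler_state_dict"])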
2 changes: 1 addition & 1 deletion tools/check_equivalence.py
@@ -27,7 +27,7 @@ def check_equivalence():
4. Asserting that their outputs are numerically very close.
"""
# 1. Define hyperparameters for the test models
B, T, R, D_EMBED, N_LAYERS, DROP = 100, 120, 16, 64, 2, 0.1
B, T, R, D_EMBED, N_LAYERS, DROP = 10, 30, 16, 64, 4, 0.1
NUM_HEADS = R
D_MODEL = R * D_EMBED
D_HEAD = D_EMBED # In this architecture, d_head equals d_embed
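
The comparison in step 4 typically follows the pattern sketched below (the two models here are trivial stand-ins that share weights; the real script compares the actual sequential and parallel implementations):

import torch
from torch import nn

B, T, R, D_EMBED = 10, 30, 16, 64
D_MODEL = R * D_EMBED

reference_model = nn.Linear(D_MODEL, D_MODEL)
parallel_model = nn.Linear(D_MODEL, D_MODEL)
parallel_model.load_state_dict(reference_model.state_dict())

x = torch.randn(B, T, D_MODEL)
with torch.no_grad():
    out_ref, out_par = reference_model(x), parallel_model(x)
max_diff = (out_ref - out_par).abs().max().item()
assert torch.allclose(out_ref, out_par, atol=1e-5), f"max abs diff: {max_diff}"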