open_eeg_bench/backbone.py (72 changes: 43 additions & 29 deletions)
@@ -18,7 +18,6 @@
 import logging
 from typing import Any, Literal
 from importlib import import_module
-import warnings
 
 from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -277,39 +276,54 @@ def load_pretrained(self, model) -> None:
                 len(skipped),
             )
 
         # Sanity check on missing keys (model keys not loaded from checkpoint).
-        param_names = {name for name, _ in model.named_parameters()}
-        allowed_prefixes = [self.head_module_name]
-        if self.training_required_modules:
-            allowed_prefixes.extend(self.training_required_modules)
+        allowed_modules = [self.head_module_name] + (
+            self.training_required_modules or []
+        )
 
-        def _is_allowed(key: str) -> bool:
-            return any(key == p or key.startswith(p + ".") for p in allowed_prefixes)
+        def covered(k: str) -> bool:
+            return any(f".{m}." in f".{k}." for m in allowed_modules)
 
-        unexpected_params = [
-            k for k in missing if k in param_names and not _is_allowed(k)
+        def describe(k: str) -> str:
+            if k in state_dict:
+                ckpt_shape = tuple(state_dict[k].shape)
+                model_shape = tuple(model_state[k].shape)
+                return f"{k} [shape mismatch: ckpt {ckpt_shape} vs model {model_shape}]"
+            return f"{k} [absent from checkpoint]"
+
+        param_names = {n for n, _ in model.named_parameters()}
+        missing_params = [k for k in missing if k in param_names and not covered(k)]
+        missing_buffers = [
+            k for k in missing if k not in param_names and not covered(k)
         ]
-        missing_buffers = [k for k in missing if k not in param_names]
 
-        if unexpected_params:
-            # TODO:
-            # - FIX the models.
-            # - Transform this warning into an error
-            warnings.warn(
-                f"Pretrained checkpoint is missing weights for backbone "
-                f"parameters outside of head_module_name and "
-                f"training_required_modules: {unexpected_params}. "
-                f"These parameters keep random initialization, making results "
-                f"seed-dependent. Either provide a checkpoint that covers them "
-                f"or add the relevant modules to training_required_modules.",
-                stacklevel=2,
-            )
+        if missing_params:
+            raise ValueError(
+                f"Pretrained checkpoint for {self.model_cls} is missing "
+                f"{len(missing_params)} learnable parameter(s) that are neither "
+                f"under head_module_name='{self.head_module_name}' nor under any "
+                f"module in training_required_modules={self.training_required_modules}:\n"
+                + "\n".join(f" - {describe(k)}" for k in missing_params)
+                + "\nThese parameters would be silently initialized from scratch, "
+                "which is almost certainly a config error. Either:\n"
+                " (a) the checkpoint is incompatible with the declared architecture "
+                "(check model_kwargs), or\n"
+                " (b) these modules should be declared in `training_required_modules` "
+                "so they are explicitly trained from scratch (note: this will "
+                "categorize the model separately in evaluations)."
+            )
 
         if missing_buffers:
-            warnings.warn(
-                f"Pretrained checkpoint does not cover the following "
-                f"buffers: {missing_buffers}. They keep their default "
-                f"initialization (typically deterministic).",
-                stacklevel=2,
+            log.warning(
+                "Pretrained checkpoint for %s is missing %d buffer(s) that are "
+                "neither under head_module_name='%s' nor under any module in "
+                "training_required_modules=%s:\n%s\n"
+                "Buffers are not trained, so missing values may be computed at "
+                "init — but verify this is intentional.",
+                self.model_cls,
+                len(missing_buffers),
+                self.head_module_name,
+                self.training_required_modules,
+                "\n".join(f" - {describe(k)}" for k in missing_buffers),
             )
 
     @staticmethod
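To make the review concrete, here is a minimal, self-contained sketch of what the new sanity check does. The toy model, module names, and shapes below are hypothetical, not taken from the benchmark, and shape-mismatched keys are assumed to already be filtered into `skipped` before `load_state_dict`, as `load_pretrained` does:

```python
import torch
from torch import nn

# Toy model: a backbone submodule, a head, and a non-trainable buffer.
model = nn.Sequential()
model.add_module("encoder", nn.Linear(4, 4))
model.add_module("head", nn.Linear(4, 2))
model.register_buffer("running_stat", torch.zeros(4))

# Checkpoint that only covers the encoder weight.
missing = model.load_state_dict(
    {"encoder.weight": torch.zeros(4, 4)}, strict=False
).missing_keys

allowed_modules = ["head"]  # head_module_name (+ training_required_modules)

def covered(k: str) -> bool:
    # The module name may appear anywhere in the dotted path, not only as a
    # prefix -- this is the behavioral change vs the old _is_allowed().
    return any(f".{m}." in f".{k}." for m in allowed_modules)

param_names = {n for n, _ in model.named_parameters()}
missing_params = [k for k in missing if k in param_names and not covered(k)]
missing_buffers = [k for k in missing if k not in param_names and not covered(k)]

print(missing_params)   # ['encoder.bias'] -> load_pretrained raises ValueError
print(missing_buffers)  # ['running_stat'] -> load_pretrained logs a warning
```

Note the depth-insensitive match: a key like `blocks.0.temporal_embedding.weight` is covered by a `temporal_embedding` entry, which the old prefix-based `_is_allowed` would have missed.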
open_eeg_bench/default_configs/backbones.py (97 changes: 5 additions & 92 deletions)
@@ -12,17 +12,6 @@
 def biot(**overrides) -> PretrainedBackbone:
     defaults = dict(
         model_cls="braindecode.models.BIOT",
-        model_kwargs=dict(
-            embed_dim=256,
-            num_heads=8,
-            num_layers=4,
-            drop_prob=0.5,
-            att_drop_prob=0.2,
-            att_layer_drop_prob=0.2,
-            hop_length=100,
-            max_seq_len=1024,
-            return_feature=False,
-        ),
         # w1, w2 are the two Linear layers inside FeedForward blocks
         peft_ff_modules=["w1", "w2"],
         normalization=PercentileScale(q=95.0),
@@ -36,25 +25,10 @@ def biot(**overrides) -> PretrainedBackbone:
 def labram(**overrides) -> PretrainedBackbone:
     defaults = dict(
         model_cls="braindecode.models.Labram",
-        model_kwargs=dict(
-            patch_size=200,
-            embed_dim=200,
-            num_layers=12,
-            num_heads=10,
-            mlp_ratio=4.0,
-            qkv_bias=False,
-            qk_scale=None,
-            drop_prob=0.1,
-            attn_drop_prob=0.0,
-            drop_path_prob=0.1,
-            use_abs_pos_emb=True,
-            use_mean_pooling=True,
-            init_scale=0.001,
-            init_values=0.1,
-            neural_tokenizer=True,
-        ),
         # mlp.0 and mlp.2 are the two Linear layers inside the MLP block
         peft_ff_modules=["mlp.0", "mlp.2"],
+        # The temporal_embedding was trained for n_times=3000 only (so there will be a shape mismatch)
+        training_required_modules=["temporal_embedding"],
         normalization=DivideByConstant(factor=100.0),
         hub_repo="braindecode/labram-pretrained",
     )
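The Labram `temporal_embedding` entry above is exactly the case the new `describe()` helper reports on: the key can exist in the checkpoint but with an incompatible shape, so it is skipped at load time and must be declared in `training_required_modules` to avoid the new ValueError. A sketch of the resulting message, with hypothetical tensor shapes:

```python
import torch

# Hypothetical checkpoint/model states around a temporal_embedding pretrained
# for n_times=3000 while the model was built for a shorter input.
state_dict = {"temporal_embedding.weight": torch.zeros(16, 200)}
model_state = {"temporal_embedding.weight": torch.zeros(11, 200)}

def describe(k: str) -> str:
    # Same logic as the helper added in backbone.py.
    if k in state_dict:
        ckpt_shape = tuple(state_dict[k].shape)
        model_shape = tuple(model_state[k].shape)
        return f"{k} [shape mismatch: ckpt {ckpt_shape} vs model {model_shape}]"
    return f"{k} [absent from checkpoint]"

print(describe("temporal_embedding.weight"))
# temporal_embedding.weight [shape mismatch: ckpt (16, 200) vs model (11, 200)]
```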
@@ -64,26 +38,9 @@ def labram(**overrides) -> PretrainedBackbone:
 def bendr(**overrides) -> PretrainedBackbone:
     defaults = dict(
-        model_cls="braindecode.models.BENDR",
-        model_kwargs=dict(
-            encoder_h=512,
-            contextualizer_hidden=3076,
-            transformer_layers=8,
-            transformer_heads=8,
-            position_encoder_length=25,
-            enc_width=(3, 2, 2, 2, 2, 2),
-            enc_downsample=(3, 2, 2, 2, 2, 2),
-            drop_prob=0.1,
-            layer_drop=0.0,
-            projection_head=False,
-            start_token=-5,
-            final_layer=True,
-            n_chans_pretrained=20,
-            encoder_only=True,
-        ),
+        model_cls="braindecode.models.InterpolatedBENDR",
         # linear1, linear2 are the FFN layers in TransformerEncoderLayer
         peft_ff_modules=["linear1", "linear2"],
-        training_required_modules=["channel_projection"],
         normalization=MinMaxScale(),
         hub_repo="braindecode/braindecode-bendr",
     )
@@ -94,20 +51,6 @@ def bendr(**overrides) -> PretrainedBackbone:
 def cbramod(**overrides) -> PretrainedBackbone:
     defaults = dict(
         model_cls="braindecode.models.CBraMod",
-        model_kwargs=dict(
-            patch_size=200,
-            dim_feedforward=800,
-            n_layer=12,
-            nhead=8,
-            emb_dim=200,
-            drop_prob=0.1,
-            channels_kernel_stride_padding_norm=[
-                [25, 49, 25, 24, [5, 25]],
-                [25, 3, 1, 1, [5, 25]],
-                [25, 3, 1, 1, [5, 25]],
-            ],
-            return_encoder_output=False,
-        ),
         # linear1, linear2 are the FFN layers in CrissCrossTransformerEncoderLayer
         peft_ff_modules=["linear1", "linear2"],
         normalization=DivideByConstant(factor=100.0),
@@ -119,29 +62,10 @@ def cbramod(**overrides) -> PretrainedBackbone:
 def signal_jepa(**overrides) -> PretrainedBackbone:
     defaults = dict(
-        model_cls="braindecode.models.SignalJEPA",
-        model_kwargs=dict(
-            feature_encoder__conv_layers_spec=[
-                (8, 32, 8),
-                (16, 2, 2),
-                (32, 2, 2),
-                (64, 2, 2),
-                (64, 2, 2),
-            ],
-            feature_encoder__mode="default",
-            feature_encoder__conv_bias=False,
-            pos_encoder__spat_dim=30,
-            pos_encoder__time_dim=34,
-            pos_encoder__sfreq_features=1.0,
-            transformer__d_model=64,
-            transformer__num_encoder_layers=8,
-            transformer__num_decoder_layers=4,
-            transformer__nhead=8,
-            drop_prob=0.0,
-        ),
+        model_cls="braindecode.models.InterpolatedSignalJEPA",
         # linear1, linear2 are the FFN layers in TransformerEncoderLayer
         peft_ff_modules=["linear1", "linear2"],
-        checkpoint_url="https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth",
+        hub_repo="braindecode/signal-jepa",
     )
     defaults.update(overrides)
     return PretrainedBackbone(**defaults)
@@ -150,17 +74,6 @@ def signal_jepa(**overrides) -> PretrainedBackbone:
 def reve(**overrides) -> PretrainedBackbone:
     defaults = dict(
         model_cls="braindecode.models.REVE",
-        model_kwargs=dict(
-            embed_dim=512,
-            depth=22,
-            heads=8,
-            head_dim=64,
-            mlp_dim_ratio=2.66,
-            use_geglu=True,
-            patch_size=200,
-            patch_overlap=20,
-            attention_pooling=False,
-        ),
         # net.1, net.3 are the two Linear layers inside FeedForward.net
         peft_ff_modules=["net.1", "net.3"],
         normalization=WindowZScore(clip_sigma=15.0),
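With the hard-coded `model_kwargs` removed, each architecture now uses whatever defaults the braindecode class itself ships, and the `defaults.update(overrides)` pattern remains the escape hatch. A usage sketch, assuming `model_kwargs` is still an accepted `PretrainedBackbone` field (the values mirror the old BIOT defaults this PR deletes):

```python
from open_eeg_bench.default_configs.backbones import biot

# Default: hyperparameters come from braindecode.models.BIOT itself, so they
# stay in sync with the published checkpoints.
backbone = biot()

# Overrides are merged into the PretrainedBackbone kwargs, so a divergent
# architecture can still be requested explicitly -- and with the new strict
# check, a checkpoint that does not cover it fails loudly with ValueError.
custom = biot(model_kwargs=dict(embed_dim=256, num_heads=8, num_layers=4))
```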
pyproject.toml (3 changes: 3 additions & 0 deletions)
@@ -43,6 +43,9 @@ Changelog = "https://github.com/braindecode/OpenEEGBench/blob/master/CHANGELOG.m
 benchopt = ["benchopt>=1.9"]
 dev = ["pytest", "pytest-cov"]
 
+[tool.uv.sources]
+braindecode = { git = "https://github.com/braindecode/braindecode.git" }
+
 [tool.setuptools.packages.find]
 where = ["."]
 include = ["open_eeg_bench*"]