Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
505efde
Push to test
Feb 2, 2026
a0e5b38
Fix merge issue
sophie-xhonneux Feb 2, 2026
7e95bef
Claude fixing things
Feb 2, 2026
2e1bd76
Fixing Betas expected everywhere
Feb 2, 2026
84fa7d1
First commit
Feb 3, 2026
c98c746
use existing implementation
Feb 3, 2026
c2495b5
Add Layerscale etc to default config
sophie-xhonneux Feb 3, 2026
68aa2f4
Make JEPA default config for testing
sophie-xhonneux Feb 3, 2026
c0ce9dd
Add assert to prevent silent errors
Feb 4, 2026
aebe434
Merge branch 'develop' into sophiex/dev/test-layerscale-etc
Feb 4, 2026
71d2cce
Add collapse monitoring
Feb 4, 2026
1d29611
Fix bug
Feb 4, 2026
bc92ae7
Fix SVD computation failing
Feb 4, 2026
7693c19
Reduce variables logged
Feb 4, 2026
7f8de00
Fix EMA beta value computation
Feb 4, 2026
c3eb019
Refactor get_current_beta to ema.py
Feb 4, 2026
59a0a89
Sensible default for ema in jepa
sophie-xhonneux Feb 4, 2026
505331c
Merge branch 'sophiex/dev/monitor-collapse' into sophiex/dev/test-lay…
sophie-xhonneux Feb 4, 2026
4a091c8
New defaults
sophie-xhonneux Feb 5, 2026
32d951b
Implement Frozenteacher
Feb 6, 2026
3298252
Test config
sophie-xhonneux Feb 6, 2026
b4c46b1
Refactor frozen teacher creation
Feb 6, 2026
590d366
Fix stuff
Feb 6, 2026
64ae9f1
Fix
Feb 6, 2026
4444b04
Debug more
Feb 6, 2026
c3e52d0
Enable frozen models not trained with SSL
Feb 6, 2026
211f477
Improve code quality
Feb 6, 2026
491a69d
Test config
sophie-xhonneux Feb 6, 2026
08dbf6f
Update jepa config
sophie-xhonneux Feb 6, 2026
1133018
Try SALT training
sophie-xhonneux Feb 7, 2026
11b2e60
Fix model path loading
Feb 16, 2026
e08b7f8
Fix model_path
sophie-xhonneux Feb 16, 2026
b42a778
Fix inference corner case (#1818)
clessig Feb 6, 2026
8521b8c
fix latent_loss check in mode handling (#1784)
TillHae Feb 6, 2026
57c8518
Streamline run_train.py so it is suitable to be run both as a script …
grassesi Feb 6, 2026
424e188
Sgrasse/develop/435 unify dataset access (#1757)
grassesi Feb 6, 2026
141315b
Fix plot_train (#1831)
clessig Feb 13, 2026
b5b4a44
nse_metric (#1833)
jesicapinon Feb 16, 2026
f0d4a06
Random encoder fientuning on NPPATMS
sophie-xhonneux Feb 17, 2026
66394b5
Update configs to fix leak
sophie-xhonneux Feb 17, 2026
0884ae6
Plotting config
sophie-xhonneux Feb 18, 2026
3100aa9
Add IASI config
sophie-xhonneux Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions config/config_jepa.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 2
ae_global_num_blocks: 0
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
Expand All @@ -37,7 +37,7 @@ ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2
ae_global_trailing_layer_norm: False

ae_aggregation_num_blocks: 8
ae_aggregation_num_blocks: 12
ae_aggregation_num_heads: 32
ae_aggregation_dropout_rate: 0.1
ae_aggregation_with_qk_lnorm: True
Expand Down Expand Up @@ -130,10 +130,33 @@ data_loading :

# config for training
training_config:

# training_mode: "masking", "student_teacher", "latent_loss"
training_mode: ["student_teacher"]

# Collapse monitoring for SSL training (JEPA/DINO/iBOT)
# Detects representation collapse via various metrics
collapse_monitoring:
enabled: true
compute_frequency: 100 # batches between metric computations
log_frequency: 100 # batches between metric logging
metrics:
effective_rank:
enabled: true
tensor_source: "both" # "student", "teacher", or "both"
sample_size: 2048 # max samples for SVD (0 = no sampling)
singular_values:
enabled: true
tensor_source: "both"
sample_size: 2048
dimension_variance:
enabled: true
tensor_source: "both" # cheap to compute, good early indicator
prototype_entropy:
enabled: true # only applies to DINO
ema_beta:
enabled: true

num_mini_epochs: 32
samples_per_mini_epoch: 4096
shuffle: True
Expand All @@ -148,25 +171,36 @@ training_config:

learning_rate_scheduling :
lr_start: 1e-6
lr_max: 5e-5
lr_max: 1e-4
lr_final_decay: 1e-6
lr_final: 0.0
num_steps_warmup: 512
num_steps_warmup: 4096
num_steps_cooldown: 512
policy_warmup: "cosine"
policy_decay: "constant"
policy_cooldown: "linear"
parallel_scaling_policy: "sqrt"

optimizer:
grad_clip: 1.0
weight_decay: 0.1
# Optimizer type: "adamw" (default) or "muon_adamw" (Muon for hidden weights, AdamW for embeddings/heads)
type: "muon_adamw"
grad_clip: 0.1
weight_decay: 0.05
log_grad_norms: False
adamw :
# parameters are scaled by number of DDP workers
beta1 : 0.975
beta2 : 0.9875
eps : 2e-08
muon:
# Learning rate multiplier for Muon relative to base LR (muon_lr = base_lr * lr_multiplier)
lr_multiplier: 30.0
# Momentum factor for Muon SGD
momentum: 0.95
# Use Nesterov momentum
nesterov: true
# Weight decay for Muon parameters (uses optimizer.weight_decay if not specified)
weight_decay: 0.05

losses : {
"student-teacher": {
Expand All @@ -179,16 +213,20 @@ training_config:
"num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768,
"dropout_rate": 0.1,
target_source_correspondence: {0 : {0 : "subset"} },
},
},
},
target_and_aux_calc: { "EMATeacher" :
{ ema_ramp_up_ratio : 0.09,
ema_halflife_in_thousands: 1e-3,
model_param_overrides : {
training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}}
},
}
}
target_and_aux_calc: {FrozenTeacher: {
teacher_run_id: "yoqxf234", # "zosrc8ti", # Required
teacher_mini_epoch: -1}},
# },
# target_and_aux_calc: { "EMATeacher" :
# { ema_ramp_up_ratio : null,
# ema_halflife_in_thousands: 1e-1,
# model_param_overrides : {
# training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}}
# },
# }
# }
}
}

Expand Down
90 changes: 6 additions & 84 deletions config/config_jepa_finetuning.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,65 +7,8 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

embed_orientation: "channels"
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

ae_local_dim_embed: 1024
ae_local_num_blocks: 4
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 2
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
# TODO: switching to < 1 triggers triton-related issues.
# See https://github.com/ecmwf/WeatherGenerator/issues/1050
ae_global_att_dense_rate: 1.0
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2
ae_global_trailing_layer_norm: False

ae_aggregation_num_blocks: 8
ae_aggregation_num_heads: 32
ae_aggregation_dropout_rate: 0.1
ae_aggregation_with_qk_lnorm: True
ae_aggregation_att_dense_rate: 1.0
ae_aggregation_block_factor: 64
ae_aggregation_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True
num_class_tokens: 1
num_register_tokens: 7

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
fe_num_blocks: 6
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
fe_impute_latent_noise_std: 0.0 # 1e-4
# currently fixed to 1.0 (due to limitations with flex_attention and triton)
forecast_att_dense_rate: 1.0
with_step_conditioning: True # False

healpix_level: 5

# performance
with_mixed_precision: True
Expand All @@ -78,11 +21,6 @@ mixed_precision_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

freeze_modules: ".*encoder.*|.*latent_pre_norm.*|.*latent_heads.*"

Expand All @@ -92,7 +30,9 @@ zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore
#####################################

# streams_directory: "./config/streams/era5_1deg/"
streams_directory: "./config/streams/era5_synop_finetuning/"
# streams_directory: "./config/streams/era5_synop_finetuning/"
# streams_directory: "./config/streams/era5_nppatms_finetuning/"
streams_directory: "./config/streams/era5_iasi_finetuning/"
streams: ???

general:
Expand Down Expand Up @@ -139,8 +79,8 @@ training_config:
samples_per_mini_epoch: 4096
shuffle: True

start_date: 1979-01-01T00:00
end_date: 2022-12-31T00:00
start_date: 2012-01-01T00:00
end_date: 2021-12-31T00:00

time_window_step: 06:00:00
time_window_len: 06:00:00
Expand Down Expand Up @@ -172,24 +112,6 @@ training_config:
losses : {
"student-teacher": {
enabled: False,
type: LossLatentSSLStudentTeacher,
weight: 1.0,
loss_fcts : {
"JEPA": {
'weight': 4, "loss_extra_args": {}, "out_dim": 2048, "head": transformer,
"num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768,
"dropout_rate": 0.1,
target_source_correspondence: {0 : {0 : "subset"} },
},
},
target_and_aux_calc: { "EMATeacher" :
{ ema_ramp_up_ratio : 0.09,
ema_halflife_in_thousands: 1e-3,
model_param_overrides : {
training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}}
},
}
}
},
"physical": {
type: LossPhysical,
Expand Down Expand Up @@ -271,7 +193,7 @@ validation_config:
# write samples in normalized model space
normalized_samples: False,
# output streams to write; default all
streams: ["SurfaceCombined"],
streams: ["METOPBIASI"],
}

# run validation before training starts (mainly for model development)
Expand Down
Loading