
Sophiex/dev/pretrained frozen teacher #1824

Open

sophie-xhonneux wants to merge 30 commits into develop from sophiex/dev/pretrained-frozen-teacher

Conversation

@sophie-xhonneux
Contributor

Description

The goal is to train the student against a frozen, pre-trained teacher (e.g. one pre-trained with MAE).

Issue Number

#1815

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and noted the run_id(s) in a comment: launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

Sophie Xhonneux and others added 27 commits February 2, 2026 21:53
Fix

The issue is that we were passing cf.training_config (the current training config), but the teacher model's latent heads are defined by teacher_config. We need to pass the teacher's training config so that the postprocessing keys match the teacher model's outputs.

The fix now:
  1. FrozenTeacher inspects the teacher model's actual latent_heads attribute to determine what postprocessing is needed
  2. Sets up JEPA/DINO/iBOT postprocessing based on which heads exist (using an identity transform for all of them, with warnings for DINO/iBOT, since full centering isn't supported for frozen teachers); see the sketch below
  3. Updates the tests to use models with latent_heads attributes
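A minimal sketch of that inspection step, assuming `latent_heads` is a dict keyed by loss name (the dict layout, the helper name, and the warning wording are assumptions; the head names and the identity fallback come from this PR):

```python
import logging

logger = logging.getLogger(__name__)


def setup_postprocessing_from_teacher(teacher_model):
    """Hypothetical helper: build per-head postprocessing from the teacher's
    actual latent_heads, falling back to an identity transform everywhere."""
    postprocessing = {}
    for head_name in getattr(teacher_model, "latent_heads", {}):
        if head_name in ("DINO", "iBOT"):
            # Full centering isn't supported for frozen teachers, so warn
            # and use the identity transform instead.
            logger.warning(
                "Using identity postprocessing for %s head; full centering "
                "is not supported for frozen teachers.",
                head_name,
            )
        postprocessing[head_name] = lambda latents: latents  # identity
    return postprocessing
```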
Summary of Changes

  Key insight from your feedback: the frozen teacher may have been pre-trained with any method
  (forecasting, MAE, etc.) and doesn't need to have SSL latent heads. We should therefore:
  1. Use the student's training config to know which SSL losses are needed
  2. Add identity heads (LatentPredictionHeadIdentity) to the teacher if they don't exist
  3. Use identity postprocessing (JEPATargetProcessing) for all SSL losses; the resulting __init__ flow is sketched after this list
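A hedged sketch of that flow; the class and helper names follow this PR, while the stub bodies and the config layout (a `losses` mapping) are assumptions for illustration:

```python
# Stand-in stubs so the sketch runs on its own; in the PR these classes
# come from the weathergen codebase.
class JEPATargetProcessing:  # identity postprocessing, per the PR
    def __call__(self, latents):
        return latents


class LatentPredictionHeadIdentity:  # identity latent head, per the PR
    def __call__(self, x):
        return x


class FrozenTeacher:
    def __init__(self, teacher_model, training_cfg):
        self.model = teacher_model
        # 1. The *student's* config decides which SSL losses are needed.
        required = self._get_required_ssl_heads(training_cfg)
        # 2. Add identity heads for any loss the teacher lacks.
        self._ensure_identity_heads(required)
        # 3. Identity postprocessing for every SSL loss.
        self.postprocessing = {name: JEPATargetProcessing() for name in required}

    def _get_required_ssl_heads(self, training_cfg):
        # Assumed config layout; only the {"JEPA"} fallback is stated in the PR.
        losses = getattr(training_cfg, "losses", None) or {}
        return {n for n in losses if n in ("JEPA", "DINO", "iBOT")} or {"JEPA"}

    def _ensure_identity_heads(self, required):
        heads = getattr(self.model, "latent_heads", None)
        if heads is None:
            heads = self.model.latent_heads = {}
        for name in required:
            heads.setdefault(name, LatentPredictionHeadIdentity())
```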

  Changes Made

  src/weathergen/train/target_and_aux_ssl_teacher.py:
  - Added an import for LatentPredictionHeadIdentity
  - Rewrote FrozenTeacher.__init__ to:
    - Accept training_cfg (the student's config) to determine the required SSL heads
    - Call _get_required_ssl_heads() to extract the loss names from the config
    - Call _ensure_identity_heads() to add missing heads to the teacher model
    - Set up identity postprocessing for all SSL losses
  - Added _get_required_ssl_heads(): extracts the SSL loss names from the training config, defaulting to {"JEPA"} if none are found (behaviour demonstrated below)
  - Added _ensure_identity_heads(): adds a LatentPredictionHeadIdentity for any missing heads
  - Updated from_pretrained() to pass cf.training_config to the constructor
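To make the default concrete, here is a standalone sketch of the extraction helper with a quick check of both branches; the `losses` attribute is an assumed config layout, and only the `{"JEPA"}` fallback is stated above:

```python
from types import SimpleNamespace

_SSL_LOSSES = ("JEPA", "DINO", "iBOT")


def get_required_ssl_heads(training_cfg):
    """Standalone sketch of FrozenTeacher._get_required_ssl_heads."""
    configured = getattr(training_cfg, "losses", None) or {}
    names = {name for name in configured if name in _SSL_LOSSES}
    return names or {"JEPA"}  # default when no SSL losses are configured


# Both branches:
print(get_required_ssl_heads(SimpleNamespace(losses={"DINO": 1.0, "mse": 1.0})))  # {'DINO'}
print(get_required_ssl_heads(SimpleNamespace()))  # {'JEPA'}
```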

  tests/test_encoder_teacher.py:
  - Added model_without_latent_heads fixture (simulates a forecasting-only teacher)
  - Added 5 new tests:
    - test_frozen_teacher_adds_identity_heads_when_missing
    - test_frozen_teacher_uses_training_cfg_for_heads
    - test_frozen_teacher_defaults_to_jepa_without_config
    - test_frozen_teacher_preserves_existing_heads
    - test_frozen_teacher_all_postprocessing_is_identity
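For illustration, a hedged sketch of the first of these tests, reusing the FrozenTeacher sketch above (the fixture body and the None config are assumptions; the fixture and test names come from this PR):

```python
import pytest


@pytest.fixture
def model_without_latent_heads():
    # Simulates a forecasting-only teacher: no latent_heads attribute at all.
    class ForecastOnlyModel:
        pass

    return ForecastOnlyModel()


def test_frozen_teacher_adds_identity_heads_when_missing(model_without_latent_heads):
    # With no config, the sketch defaults to {"JEPA"} and should attach an
    # identity head for it to the previously head-less teacher.
    teacher = FrozenTeacher(model_without_latent_heads, training_cfg=None)
    assert isinstance(
        teacher.model.latent_heads["JEPA"], LatentPredictionHeadIdentity
    )
```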
