Add adapt_checkpoint_hparams hook for customizing checkpoint hyperparameter loading #21408
Conversation
Add adapt_checkpoint_hparams hook for customizing checkpoint hyperparameter loading

Fixes Lightning-AI#21255

This commit adds the adapt_checkpoint_hparams() public method to LightningCLI, allowing users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This is particularly useful when using checkpoints from a TrainingModule with a different InferenceModule class that has different __init__ parameters.

Problem: When loading a checkpoint trained with TrainingModule(lr=1e-3) into an InferenceModule() that doesn't accept 'lr' as a parameter, the CLI would fail during instantiation because it tries to pass all checkpoint hyperparameters to the new module class.

Solution: Added adapt_checkpoint_hparams() hook that is called in _parse_ckpt_path() after loading checkpoint hyperparameters but before applying them. Users can override this method to:

- Remove training-specific hyperparameters (e.g., lr, weight_decay)
- Modify _class_path for subclass mode
- Transform hyperparameter names/values
- Completely disable checkpoint hyperparameters by returning {}

Example usage:

    class MyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams):
            checkpoint_hparams.pop('lr', None)
            checkpoint_hparams.pop('weight_decay', None)
            return checkpoint_hparams

This approach is preferable to:

- Disabling checkpoint loading entirely (loses valuable hyperparameter info)
- Adding CLI arguments (deviates from Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)

The hook provides maximum flexibility while maintaining backward compatibility (the default implementation returns hyperparameters unchanged).
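To make the call order concrete, here is a minimal illustrative sketch of the flow described in the commit message. This is not Lightning's actual _parse_ckpt_path code; the helper name and surrounding logic are assumptions, though "hyper_parameters" is the standard checkpoint key.

```python
from typing import Any

from lightning.pytorch.cli import LightningCLI


def resolve_checkpoint_hparams(cli: LightningCLI, checkpoint: dict[str, Any]) -> dict[str, Any]:
    """Illustrative helper only: mirrors the ordering described in the commit message."""
    # 1) hyperparameters as stored in the checkpoint
    hparams = dict(checkpoint.get("hyper_parameters", {}))
    # 2) the new hook lets a LightningCLI subclass adapt them (default: returned unchanged)
    hparams = cli.adapt_checkpoint_hparams(hparams)
    # 3) an empty dict means no checkpoint hyperparameters are applied
    return hparams
```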
Pull request overview
This PR adds a public adapt_checkpoint_hparams() hook to LightningCLI that enables users to customize hyperparameters loaded from checkpoints before model instantiation. This addresses the issue of loading checkpoints across different module classes (e.g., from TrainingModule to InferenceModule) where incompatible __init__ parameters would otherwise cause failures.
Key Changes:
- Added adapt_checkpoint_hparams() public method with comprehensive documentation
- Integrated the hook into _parse_ckpt_path() to allow customization before hyperparameter application
- Maintained backward compatibility with a default no-op implementation
```python
def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
    """Adapt checkpoint hyperparameters before instantiating the model class.

    This method allows for customization of hyperparameters loaded from a checkpoint when
    using a different model class than the one used for training. For example, when loading
    a checkpoint from a TrainingModule to use with an InferenceModule that has different
    ``__init__`` parameters, you can remove or modify incompatible hyperparameters.

    Args:
        checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

    Returns:
        Dictionary of adapted hyperparameters to be used for model instantiation.

    Example::

        class MyCLI(LightningCLI):
            def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
                # Remove training-specific hyperparameters not needed for inference
                checkpoint_hparams.pop("lr", None)
                checkpoint_hparams.pop("weight_decay", None)
                return checkpoint_hparams

    Note:
        If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
        hyperparameters, you may need to modify it as well to point to your new module class.

    """
    return checkpoint_hparams
```
Copilot AI (Dec 6, 2025)
The new adapt_checkpoint_hparams() hook lacks test coverage. Given that tests/tests_pytorch/test_cli.py contains comprehensive tests for checkpoint loading functionality (e.g., test_lightning_cli_ckpt_path_argument_hparams and test_lightning_cli_ckpt_path_argument_hparams_subclass_mode), tests should be added to verify:
- The hook is called when loading checkpoint hyperparameters
- Modifications made in the hook are applied correctly
- Returning an empty dict properly skips checkpoint hyperparameter loading
- The hook works in both regular and subclass modes
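A rough sketch of what such a test could look like, loosely following the style of the existing ckpt_path tests; the argv flags, fixture usage, and class names here are assumptions rather than the PR's actual tests:

```python
from unittest import mock

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel


class TrainingBoringModel(BoringModel):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()  # puts `lr` into the checkpoint's hyperparameters


class AdaptingCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, checkpoint_hparams):
        checkpoint_hparams.pop("lr", None)  # drop the key BoringModel.__init__ does not accept
        return checkpoint_hparams


def test_adapt_checkpoint_hparams_removes_keys(tmp_path):
    fit_argv = [
        "any.py", "fit",
        f"--trainer.default_root_dir={tmp_path}",
        "--trainer.max_epochs=1",
        "--trainer.limit_train_batches=1",
        "--trainer.limit_val_batches=1",
    ]
    with mock.patch("sys.argv", fit_argv):
        cli = AdaptingCLI(TrainingBoringModel)
    ckpt_path = str(cli.trainer.checkpoint_callback.best_model_path)

    # Without the hook, `lr` from the checkpoint would be passed to BoringModel.__init__ and fail.
    with mock.patch("sys.argv", ["any.py", "validate", f"--ckpt_path={ckpt_path}"]):
        AdaptingCLI(BoringModel)
```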
```python
        else:
            self.config = parser.parse_args(args)

    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
```
Copilot AI (Dec 6, 2025)
Use lowercase dict instead of Dict for type annotations to align with the modern Python 3.9+ style used throughout this file. Change Dict[str, Any] to dict[str, Any] in both the parameter and return type annotations.
```python
    Example::

        class MyCLI(LightningCLI):
            def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
```
Copilot AI (Dec 6, 2025)
Use lowercase dict instead of Dict for type annotations to align with the modern Python 3.9+ style used throughout this file. Change Dict[str, Any] to dict[str, Any] in both the parameter and return type annotations.
What does this PR do?
Fixes #21255
This PR adds a public adapt_checkpoint_hparams() hook to LightningCLI that allows users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This solves the problem of loading checkpoints across different module classes (e.g., from TrainingModule to InferenceModule).

Problem
When using LightningCLI with checkpoints, hyperparameters saved during training are automatically loaded and applied when running other subcommands (test, predict, etc.). This is convenient when using the same module class, but fails when using a different class with incompatible __init__ parameters.

Example scenario:
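The two classes below are an illustrative sketch of that scenario (names taken from the description above; the layer is a placeholder):

```python
import torch
from lightning.pytorch import LightningModule


class TrainingModule(LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()  # `lr` is saved into the checkpoint's hyperparameters
        self.layer = torch.nn.Linear(32, 2)


class InferenceModule(LightningModule):
    def __init__(self):  # no `lr` parameter
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
```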
Running cli predict --ckpt_path checkpoint.ckpt with InferenceModule fails because the CLI tries to pass lr=1e-3 from the checkpoint to InferenceModule.__init__().

Solution
Added adapt_checkpoint_hparams() public method that users can override to customize loaded hyperparameters (see the example use cases below).

Implementation Details
- Added adapt_checkpoint_hparams() public method in LightningCLI
- Modified _parse_ckpt_path() to call the hook after loading but before applying hyperparameters

Why This Approach?
As discussed in #21255, this is superior to alternatives:
- Disabling checkpoint loading entirely (loses useful hyperparameters such as hidden_dim)
- Adding CLI arguments (deviates from the Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)

Testing
The implementation handles _class_path modification when needed (see the docstring note on subclass mode).

Example Use Cases
Remove training-only parameters:
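A sketch of such an override, using the key names from the commit message's example (assumes the hook added in this PR):

```python
from lightning.pytorch.cli import LightningCLI


class MyCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, checkpoint_hparams):
        # Drop optimizer settings that the inference module's __init__ does not accept
        checkpoint_hparams.pop("lr", None)
        checkpoint_hparams.pop("weight_decay", None)
        return checkpoint_hparams
```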
Change module class in subclass mode:
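A sketch for subclass mode, based on the docstring's note about _class_path; the target class path is a placeholder:

```python
from lightning.pytorch.cli import LightningCLI


class MyCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, checkpoint_hparams):
        # Redirect the stored class path from the training class to the inference class
        if "_class_path" in checkpoint_hparams:
            checkpoint_hparams["_class_path"] = "my_project.InferenceModule"  # placeholder path
        return checkpoint_hparams
```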
Disable all checkpoint hyperparameters:
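A sketch; per the commit message, returning an empty dict disables checkpoint hyperparameters entirely:

```python
from lightning.pytorch.cli import LightningCLI


class MyCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, checkpoint_hparams):
        # Ignore everything stored in the checkpoint; only CLI/config values are used
        return {}
```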
Does your PR introduce any breaking changes?
No, this is a purely additive change. The default implementation returns hyperparameters unchanged, preserving existing behavior.
cc: @mauvilsa @ziw-liu
📚 Documentation preview 📚: https://pytorch-lightning--21408.org.readthedocs.build/en/21408/