fix: support save/load for fit_mode='fit_with_cache' estimators #977
devangpratap wants to merge 3 commits into
Code Review
This pull request enables saving and loading for models using the fit_with_cache mode by implementing weight-stripping logic during serialization and re-injection during loading. Key changes include the removal of previous restrictions on KV cache persistence and the addition of comprehensive cross-device tests. Feedback suggests explicitly moving models to the CPU in _set_models to prevent unexpected GPU memory usage and optimizing _create_copy_for_pickling to avoid memory spikes by detaching weights before performing a deep copy.
def _set_models(self, models: list[Architecture]) -> None:
    # self.models contains weight-stripped shells from the pickle.
    # Inject fresh weights from the base models.
    for model, ensemble_member in zip(self.models, self.ensemble_members):
        fresh = models[ensemble_member.config._model_index]
        fresh_params = dict(fresh.named_parameters())
        for name, param in model.named_parameters():
            if name not in fresh_params:
                raise RuntimeError(
                    f"Parameter '{name}' not found in the base model. "
                    "The saved fit state is incompatible with the current "
                    "model checkpoint."
                )
            param.data = fresh_params[name].data.clone()
        if self.force_inference_dtype is not None:
            model.type(self.force_inference_dtype)
In _set_models, the parameters are cloned from the base models (fresh). If the base models are on a GPU, the ensemble member models will now have their parameters on that GPU. This violates the assumption stated in _move_models_to_devices (line 1014) that models are kept on CPU between predictions.
If multiple estimators are loaded, this could lead to a GPU Out-Of-Memory (OOM) error immediately after loading. You should explicitly move each model to the CPU after restoring its parameters.
@override
def _set_models(self, models: list[Architecture]) -> None:
    # self.models contains weight-stripped shells from the pickle.
    # Inject fresh weights from the base models.
    for model, ensemble_member in zip(self.models, self.ensemble_members):
        fresh = models[ensemble_member.config._model_index]
        fresh_params = dict(fresh.named_parameters())
        for name, param in model.named_parameters():
            if name not in fresh_params:
                raise RuntimeError(
                    f"Parameter '{name}' not found in the base model. "
                    "The saved fit state is incompatible with the current "
                    "model checkpoint."
                )
            param.data = fresh_params[name].data.clone()
        if self.force_inference_dtype is not None:
            model.type(self.force_inference_dtype)
        model.cpu()
state_copy = deepcopy(self)
# Strip nn.Parameter data from models to avoid saving large weights,
The current implementation of _create_copy_for_pickling performs a deepcopy(self) before stripping the model weights. This will cause a significant memory spike as it duplicates the large foundation model weights in RAM (or GPU RAM) before they are discarded.
To avoid this, consider temporarily stripping the weights from self.models before the deepcopy and restoring them afterwards, similar to the approach used in InferenceEngineExplicitKVCache (lines 1340-1351).
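The detach-then-copy idea suggested here can be sketched in plain Python. This is only an illustration under stated assumptions, not the PR's code: `ToyEngine`, `weights`, `small_state`, and `create_copy_for_pickling` are hypothetical names, and a list of floats stands in for the model parameters.

```python
from __future__ import annotations

from copy import deepcopy


class ToyEngine:
    """Hypothetical stand-in for an inference engine holding large weights."""

    def __init__(self, weights: list[float] | None, small_state: dict[str, int]) -> None:
        self.weights = weights          # large, recoverable from a checkpoint
        self.small_state = small_state  # small, must be saved

    def create_copy_for_pickling(self) -> ToyEngine:
        # Detach the large attribute first so deepcopy never duplicates it.
        detached = self.weights
        self.weights = None
        try:
            state_copy = deepcopy(self)
        finally:
            # Always restore the original, even if deepcopy raises.
            self.weights = detached
        return state_copy


engine = ToyEngine(weights=[0.0] * 1_000_000, small_state={"n_rows": 128})
stripped = engine.create_copy_for_pickling()
```

The `try`/`finally` is the important part of the design: the live object gets its weights back on every path, so a failed copy cannot leave the estimator unusable.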
Summary
Resolves #412.
save_fit_state / load_from_fit_state previously raised NotImplementedError for estimators fitted with fit_mode="fit_with_cache". This patch adds serialization support for both KV cache inference engines (InferenceEngineCacheKV and InferenceEngineExplicitKVCache).
Approach
InferenceEngineCacheKV (v2.5 models):
Each ensemble member holds a deepcopied model with internal KV cache state stored as registered buffers, plus encoder fitted state stored as non-persistent buffers and regular attributes. These all survive deepcopy/pickle, but nn.Parameter weights do not need to be serialized (they come from the base checkpoint on load).
- _create_copy_for_pickling: deepcopy the engine, then zero out all nn.Parameter.data to strip the large weights while keeping model structure intact.
- _set_models: on load, inject fresh weights from the base model into the weight-stripped shells. Raises RuntimeError if any parameter name is missing (incompatible checkpoint).
- _move_models_to_devices: replaces the old NotImplementedError with a device assignment. iter_outputs() already handles per-predict CPU-to-device transfers.
InferenceEngineExplicitKVCache (v3 models):
KV caches are stored externally (not inside the model), so the model itself is stateless.
- _create_copy_for_pickling: temporarily detaches model_caches and kv_caches from self before deepcopy to avoid a memory spike from copying GPU tensors that get discarded. Attaches CPU copies of the KV caches to the serialized copy.
- _set_models and _move_models_to_devices: inherited from MultiDeviceInferenceEngine, work without changes.
Changes
- src/tabpfn/inference.py: removed _raise_if_kv_cache_enabled_on_save_or_load and its two call sites. Added override methods on both cache engines.
- src/tabpfn/model_loading.py: removed "not supported" docstring line from save_fitted_tabpfn_model.
- tests/test_save_load_fitted_model.py: added two parametrized tests (single and double round-trip) covering regression + classification across device pairs.
Test plan
- test__save_and_load_fit_with_cache__predictions_equal (regression + classification x device pairs)
- test__save_and_load_fit_with_cache_twice__predictions_equal (double save/load cycle, same matrix)
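The property these tests check (predictions are unchanged after one or two save/load cycles, with weights stripped on save and re-injected on load) can be illustrated with a toy, stdlib-only sketch. Everything here is hypothetical and unrelated to the real TabPFN API: `ToyFittedModel`, `toy_save`, and `toy_load` are invented names, `weights` plays the role of nn.Parameter data, and `cache` plays the role of the fitted KV cache.

```python
from __future__ import annotations

import pickle


class ToyFittedModel:
    """Hypothetical fitted model: weights come from a checkpoint, cache from fit."""

    def __init__(self, weights: dict, cache: list[float]) -> None:
        self.weights = weights
        self.cache = cache

    def stripped_state(self) -> dict:
        # Keep parameter names (the "shell") but drop the values.
        return {
            "weights": {name: None for name in self.weights},
            "cache": list(self.cache),
        }

    def set_weights(self, fresh: dict[str, list[float]]) -> None:
        # Re-inject values from the checkpoint, failing loudly on a mismatch.
        for name in self.weights:
            if name not in fresh:
                raise RuntimeError(f"Parameter '{name}' not found in the base model.")
            self.weights[name] = list(fresh[name])

    def predict(self, x: float) -> float:
        return x * sum(self.weights["w"]) + sum(self.cache)


def toy_save(model: ToyFittedModel) -> bytes:
    return pickle.dumps(model.stripped_state())


def toy_load(blob: bytes, checkpoint: dict[str, list[float]]) -> ToyFittedModel:
    state = pickle.loads(blob)
    model = ToyFittedModel(weights=state["weights"], cache=state["cache"])
    model.set_weights(checkpoint)
    return model


checkpoint = {"w": [0.5, 1.5]}
fitted = ToyFittedModel(weights={"w": [0.5, 1.5]}, cache=[1.0, 2.0])
once = toy_load(toy_save(fitted), checkpoint)
twice = toy_load(toy_save(once), checkpoint)
```

The double round trip matters because a loaded object must itself be savable: if loading left the shell in a state that the stripping step could not handle, only the second cycle would reveal it.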