Skip to content

fix: support save/load for fit_mode='fit_with_cache' estimators#977

Open
devangpratap wants to merge 3 commits into
PriorLabs:mainfrom
devangpratap:fix/save-load-fit-with-cache
Open

fix: support save/load for fit_mode='fit_with_cache' estimators#977
devangpratap wants to merge 3 commits into
PriorLabs:mainfrom
devangpratap:fix/save-load-fit-with-cache

Conversation

@devangpratap
Copy link
Copy Markdown

Summary

Resolves #412.

save_fit_state / load_from_fit_state previously raised NotImplementedError for estimators fitted with fit_mode="fit_with_cache". This patch adds serialization support for both KV cache inference engines (InferenceEngineCacheKV and InferenceEngineExplicitKVCache).

Approach

InferenceEngineCacheKV (v2.5 models):

Each ensemble member holds a deepcopied model with internal KV cache state stored as registered buffers, plus encoder fitted state stored as non-persistent buffers and regular attributes. These all survive deepcopy/pickle, but nn.Parameter weights do not need to be serialized (they come from the base checkpoint on load).

  • _create_copy_for_pickling: deepcopy the engine, then zero out all nn.Parameter.data to strip the large weights while keeping model structure intact.
  • _set_models: on load, inject fresh weights from the base model into the weight-stripped shells. Raises RuntimeError if any parameter name is missing (incompatible checkpoint).
  • _move_models_to_devices: replaces the old NotImplementedError with a device assignment. iter_outputs() already handles per-predict CPU-to-device transfers.

InferenceEngineExplicitKVCache (v3 models):

KV caches are stored externally (not inside the model), so the model itself is stateless.

  • _create_copy_for_pickling: temporarily detaches model_caches and kv_caches from self before deepcopy to avoid a memory spike from copying GPU tensors that get discarded. Attaches CPU copies of the KV caches to the serialized copy.
  • _set_models and _move_models_to_devices: inherited from MultiDeviceInferenceEngine, work without changes.

Changes

  • src/tabpfn/inference.py: removed _raise_if_kv_cache_enabled_on_save_or_load and its two call sites. Added override methods on both cache engines.
  • src/tabpfn/model_loading.py: removed "not supported" docstring line from save_fitted_tabpfn_model.
  • tests/test_save_load_fitted_model.py: added two parametrized tests (single and double round-trip) covering regression + classification across device pairs.

Test plan

  • test__save_and_load_fit_with_cache__predictions_equal (regression + classification x device pairs)
  • test__save_and_load_fit_with_cache_twice__predictions_equal (double save/load cycle, same matrix)
  • Existing save/load tests still pass (no regressions)

@devangpratap devangpratap requested a review from a team as a code owner May 20, 2026 15:49
@devangpratap devangpratap requested review from klemens-floege and removed request for a team May 20, 2026 15:49
@CLAassistant
Copy link
Copy Markdown

CLAassistant commented May 20, 2026

CLA assistant check
All committers have signed the CLA.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enables saving and loading for models using the fit_with_cache mode by implementing weight-stripping logic during serialization and re-injection during loading. Key changes include the removal of previous restrictions on KV cache persistence and the addition of comprehensive cross-device tests. Feedback suggests explicitly moving models to the CPU in _set_models to prevent unexpected GPU memory usage and optimizing _create_copy_for_pickling to avoid memory spikes by detaching weights before performing a deep copy.

Comment thread src/tabpfn/inference.py
Comment on lines +994 to +1009
def _set_models(self, models: list[Architecture]) -> None:
# self.models contains weight-stripped shells from the pickle.
# Inject fresh weights from the base models.
for model, ensemble_member in zip(self.models, self.ensemble_members):
fresh = models[ensemble_member.config._model_index]
fresh_params = dict(fresh.named_parameters())
for name, param in model.named_parameters():
if name not in fresh_params:
raise RuntimeError(
f"Parameter '{name}' not found in the base model. "
"The saved fit state is incompatible with the current "
"model checkpoint."
)
param.data = fresh_params[name].data.clone()
if self.force_inference_dtype is not None:
model.type(self.force_inference_dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In _set_models, the parameters are cloned from the base models (fresh). If the base models are on a GPU, the ensemble member models will now have their parameters on that GPU. This violates the assumption stated in _move_models_to_devices (line 1014) that models are kept on CPU between predictions.

If multiple estimators are loaded, this could lead to a GPU Out-Of-Memory (OOM) error immediately after loading. You should explicitly move each model to the CPU after restoring its parameters.

Suggested change
def _set_models(self, models: list[Architecture]) -> None:
# self.models contains weight-stripped shells from the pickle.
# Inject fresh weights from the base models.
for model, ensemble_member in zip(self.models, self.ensemble_members):
fresh = models[ensemble_member.config._model_index]
fresh_params = dict(fresh.named_parameters())
for name, param in model.named_parameters():
if name not in fresh_params:
raise RuntimeError(
f"Parameter '{name}' not found in the base model. "
"The saved fit state is incompatible with the current "
"model checkpoint."
)
param.data = fresh_params[name].data.clone()
if self.force_inference_dtype is not None:
model.type(self.force_inference_dtype)
@override
def _set_models(self, models: list[Architecture]) -> None:
# self.models contains weight-stripped shells from the pickle.
# Inject fresh weights from the base models.
for model, ensemble_member in zip(self.models, self.ensemble_members):
fresh = models[ensemble_member.config._model_index]
fresh_params = dict(fresh.named_parameters())
for name, param in model.named_parameters():
if name not in fresh_params:
raise RuntimeError(
f"Parameter '{name}' not found in the base model. "
"The saved fit state is incompatible with the current "
"model checkpoint."
)
param.data = fresh_params[name].data.clone()
if self.force_inference_dtype is not None:
model.type(self.force_inference_dtype)
model.cpu()

Comment thread src/tabpfn/inference.py Outdated
Comment on lines +984 to +985
state_copy = deepcopy(self)
# Strip nn.Parameter data from models to avoid saving large weights,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of _create_copy_for_pickling performs a deepcopy(self) before stripping the model weights. This will cause a significant memory spike as it duplicates the large foundation model weights in RAM (or GPU RAM) before they are discarded.

To avoid this, consider temporarily stripping the weights from self.models before the deepcopy and restoring them afterwards, similar to the approach used in InferenceEngineExplicitKVCache (lines 1340-1351).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support save_fitted_tabpfn_model/load_fitted_tabpfn_model with fit_mode 'fit_with_cache'

2 participants