Improve inference speed by keeping model on GPU/MPS between fit/predict calls #523

@DanieleMorotti

I tested the example code for the classifier:

from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier


# Load data
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# Initialize a classifier
clf = TabPFNClassifier(device="mps")
clf.fit(X_train, y_train)
print("Device: ", next(clf.executor_.model.parameters()).device)
# It prints "cpu"

# Predict probabilities
prediction_probabilities = clf.predict_proba(X_test)
print("ROC AUC:", roc_auc_score(y_test, prediction_probabilities[:, 1]))
# Predict labels
predictions = clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, predictions))

clf.executor_.model = clf.executor_.model.to("mps")
print("New device: ", next(clf.executor_.model.parameters()).device)
# Now it prints "mps:0"

After fit, the model parameters are reported on the CPU, even though I passed device="mps" and MPS is available. The model only moves to MPS when I call .to("mps") on it manually.
Is this the expected behavior, or is it unintended? The same thing happens when I set device="auto".
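
To rule out that only the first parameter is on the CPU, a check over all parameters may help. This is a minimal sketch: `parameter_devices` is a hypothetical helper name, not part of tabpfn, and it assumes only that the model exposes `.parameters()` and that each parameter has a `.device` attribute, as `torch.nn.Module` and its parameters do.

```python
from collections import Counter


def parameter_devices(model):
    """Count parameters per device for any object exposing .parameters().

    Hypothetical helper (not part of tabpfn). It relies only on duck typing:
    each parameter must have a .device attribute, as torch parameters do.
    """
    return Counter(str(p.device) for p in model.parameters())


# Possible usage with the classifier from the snippet above:
#   print(parameter_devices(clf.executor_.model))
```

If the returned counter has a single "cpu" key after fit, the whole model is on the CPU rather than just its first parameter.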

Debug info

PyTorch version: 2.8.0
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.6.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.3.19.1)
CMake version: version 3.31.5
Libc version: N/A

Python version: 3.12.1 (main, Jun 3 2024, 17:33:54) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-15.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Pro

Dependency Versions:

tabpfn: 2.2.1
torch: 2.8.0
numpy: 2.3.3
scipy: 1.16.2
pandas: 2.3.2
scikit-learn: 1.6.1
typing_extensions: 4.15.0
einops: 0.8.1
huggingface-hub: 0.35.0
