Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"scikit-learn>=1.2.0,<1.7",
"scipy>=1.11.1,<2",
"tabpfn>=2.0.6",
"tabpfn-common-utils @ git+https://github.com/PriorLabs/tabpfn_common_utils.git@d2ae05ec283187a5374ab16f260cacf119b06b7a",
]

requires-python = ">=3.9"
Expand Down
10 changes: 8 additions & 2 deletions src/tabpfn_extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@
__version__ = "0.1.0.dev0"

# Create alias for test_utils
from . import test_utils
from .embedding import TabPFNEmbedding
from .hpo import TunedTabPFNClassifier, TunedTabPFNRegressor
from .many_class import ManyClassClassifier
from .post_hoc_ensembles import AutoTabPFNClassifier, AutoTabPFNRegressor
from .unsupervised import TabPFNUnsupervisedModel

# Import utilities and wrapped TabPFN classes
from .utils import TabPFNClassifier, TabPFNRegressor, is_tabpfn
from .utils import (
TabPFNClassifier,
TabPFNRegressor,
is_tabpfn,
simulate_first,
test_utils,
)

__all__ = [
"test_utils",
Expand All @@ -28,4 +33,5 @@
"AutoTabPFNRegressor",
"TunedTabPFNClassifier",
"TunedTabPFNRegressor",
"simulate_first",
]
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sklearn.utils.validation import check_is_fitted

from tabpfn_extensions.misc.sklearn_compat import validate_data
from tabpfn_extensions.utils.simulator import simulate_first

from .pfn_phe import (
AutoPostHocEnsemblePredictor,
Expand Down Expand Up @@ -90,6 +91,7 @@ def __sklearn_tags__(self):
tags.estimator_type = "classifier"
return tags

@simulate_first
def fit(self, X, y, categorical_feature_indices: list[int] | None = None):
X, y = validate_data(
self,
Expand Down Expand Up @@ -268,6 +270,7 @@ def __sklearn_tags__(self):
tags.estimator_type = "regressor"
return tags

@simulate_first
def fit(self, X, y, categorical_feature_indices: list[int] | None = None):
# Validate input data

Expand Down
6 changes: 5 additions & 1 deletion src/tabpfn_extensions/sklearn_ensembles/meta_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
RandomForestTabPFNClassifier,
RandomForestTabPFNRegressor,
)
from tabpfn_extensions.utils import TabPFNClassifier, TabPFNRegressor, product_dict
from tabpfn_extensions.utils import (
TabPFNClassifier,
TabPFNRegressor,
product_dict,
)

from . import configs
from .weighted_ensemble import WeightedAverageEnsemble
Expand Down
34 changes: 34 additions & 0 deletions src/tabpfn_extensions/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Public surface of ``tabpfn_extensions.utils``.

Re-exports the wrapped TabPFN estimator classes, small helper utilities, and
the ``simulate_first`` simulation decorator under a single namespace.
"""

# Create alias for test_utils
from . import test_utils
from .simulator import simulate_first
from .utils import (
    ClientTabPFNClassifier,
    ClientTabPFNRegressor,
    LocalTabPFNClassifier,
    LocalTabPFNRegressor,
    TabPFNClassifier,
    TabPFNRegressor,
    get_device,
    get_tabpfn_models,
    infer_categorical_features,
    is_tabpfn,
    product_dict,
    softmax,
)

# Explicit public API for ``from tabpfn_extensions.utils import *``.
__all__ = [
    "get_tabpfn_models",
    "is_tabpfn",
    "test_utils",
    "simulate_first",
    "get_device",
    "TabPFNClassifier",
    "TabPFNRegressor",
    "LocalTabPFNClassifier",
    "LocalTabPFNRegressor",
    "ClientTabPFNClassifier",
    "ClientTabPFNRegressor",
    "infer_categorical_features",
    "softmax",
    "product_dict",
]
278 changes: 278 additions & 0 deletions src/tabpfn_extensions/utils/simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
from __future__ import annotations

import contextlib
import functools
import logging
import os
import threading
import time
import warnings
from typing import Literal
from unittest import mock

import numpy as np
from tabpfn_common_utils.expense_estimation import estimate_duration

# Extra latency (seconds) added per client-side prediction estimate so that
# simulated durations slightly overestimate the real ones (safer).
CLIENT_COST_ESTIMATION_LATENCY_OFFSET = 1.0


# Use thread-local variables to keep track of the current mode, simulated costs and time.
_thread_local = threading.local()


# Block of small helper functions to access and modify the thread-local
# variables used for mock prediction in a simple and unified way.
def get_is_local_tabpfn() -> bool:
    """Whether the current thread is set up to use a local TabPFN (defaults to True)."""
    use_local = getattr(_thread_local, "use_local_tabpfn", True)
    return use_local


def set_is_local_tabpfn():
    """Figure out whether local TabPFN or client is used and set thread-local variable."""
    env_says_local = os.getenv("USE_TABPFN_LOCAL", "true").lower() == "true"

    # Local mode also requires the ``tabpfn`` package to be importable.
    try:
        from tabpfn import TabPFNClassifier as _LocalClassifier  # noqa: F401

        tabpfn_installed = True
    except ImportError:
        tabpfn_installed = False

    _thread_local.use_local_tabpfn = env_says_local and tabpfn_installed


def get_mock_cost() -> float:
    """Return the simulated cost accumulated on this thread (0.0 if never set)."""
    cost = getattr(_thread_local, "cost", 0.0)
    return cost


def increment_mock_cost(value: float):
    """Add *value* to the thread-local simulated cost."""
    set_mock_cost(get_mock_cost() + value)


def set_mock_cost(value: float = 0.0):
    """Reset the thread-local simulated cost to *value* (defaults to 0.0)."""
    _thread_local.cost = value


def get_mock_time() -> float:
    """Return the simulated clock for this thread.

    NOTE(review): unlike ``get_mock_cost`` there is no default — this raises
    ``AttributeError`` if called before ``set_mock_time``; ``mock_mode`` always
    sets it first.
    """
    return _thread_local.mock_time


def set_mock_time(value: float):
    """Set the thread-local simulated clock to *value* (seconds since epoch)."""
    _thread_local.mock_time = value


def increment_mock_time(seconds: float):
    """Advance the thread-local simulated clock by *seconds*."""
    _thread_local.mock_time = get_mock_time() + seconds


# Block of functions that will replace the actual fit and predict functions in mock mode.
def mock_fit_local(self, X, y, config=None):
    """Mock replacement for a local TabPFN ``fit``.

    Keeps the training data on the estimator, because the mocked prediction
    functions need it to estimate duration and cost (the client already stores
    train data internally).
    """
    self.X_train, self.y_train = X, y
    return "mock_id"


def mock_fit_client(cls, X, y, config=None):
    """Mock replacement for the client-side ``fit``; returns a dummy train-set id."""
    return "mock_id"


def mock_predict_local(self, X_test):
    """Mock ``predict`` for classifiers; distinguishes predict from predict_proba."""
    result = mock_predict_proba_local(self, X_test, from_classifier_predict=True)
    return result


def mock_predict_proba_local(self, X_test, from_classifier_predict=False):
    """Adapt a local predict/predict_proba call to :func:`mock_predict`.

    Builds the task name, config and output-type parameters that the client
    already assembles internally before delegating to the shared mock.
    """
    is_classifier = self.__class__.__name__ == "TabPFNClassifier"
    task = "classification" if is_classifier else "regression"

    if is_classifier:
        output_type = "preds" if from_classifier_predict else "probas"
    else:
        output_type = "mean"

    return mock_predict(
        self,
        X_test,
        task,
        "dummy",
        self.X_train,
        self.y_train,
        {"n_estimators": self.n_estimators},
        {"output_type": output_type},
    )


def mock_predict(
    cls,
    X_test,
    task: Literal["classification", "regression"],
    train_set_uid: str,
    X_train,
    y_train,
    config=None,
    predict_params=None,
):
    """Mock function for prediction, which can be called instead of the real
    prediction function. Outputs random results in the expected format and
    keeps track of the simulated cost and time.

    Raises:
        ValueError: if the training data needed for the estimate is missing.

    Returns random arrays/dicts shaped like the real output, or ``None`` for
    an unrecognized ``output_type``.
    """
    if X_train is None or y_train is None:
        raise ValueError(
            "X_train and y_train must be provided in mock mode during prediction.",
        )

    # Both default to None; guard so callers that omit them don't crash on
    # ``config.get`` / ``predict_params[...]`` below.
    config = config or {}
    predict_params = predict_params or {}

    duration = estimate_duration(
        num_rows=X_train.shape[0] + X_test.shape[0],
        num_features=X_test.shape[1],
        task=task,
        tabpfn_config=config,
        latency_offset=0
        if get_is_local_tabpfn()
        else CLIENT_COST_ESTIMATION_LATENCY_OFFSET,  # To slightly overestimate (safer)
    )
    increment_mock_time(duration)

    # Cost model: cells processed times ensemble size (task-specific default).
    cost = (
        (X_train.shape[0] + X_test.shape[0])
        * X_test.shape[1]
        * config.get("n_estimators", 4 if task == "classification" else 8)
    )
    increment_mock_cost(cost)

    # Use .get so a missing "output_type" key falls back to the default branch
    # instead of raising KeyError.
    output_type = predict_params.get("output_type")

    # Return random result in the correct format
    if task == "classification":
        if not output_type or output_type == "preds":
            return np.random.rand(X_test.shape[0])
        elif output_type == "probas":
            probs = np.random.rand(X_test.shape[0], len(np.unique(y_train)))
            return probs / probs.sum(axis=1, keepdims=True)

    elif task == "regression":
        if not output_type or output_type == "mean":
            return np.random.rand(X_test.shape[0])
        elif output_type == "full":
            return {
                "logits": np.random.rand(X_test.shape[0], 5000),
                "mean": np.random.rand(X_test.shape[0]),
                "median": np.random.rand(X_test.shape[0]),
                "mode": np.random.rand(X_test.shape[0]),
                "quantiles": np.random.rand(3, X_test.shape[0]),
                "borders": np.random.rand(5001),
                "ei": np.random.rand(X_test.shape[0]),
                "pi": np.random.rand(X_test.shape[0]),
            }
    return None


@contextlib.contextmanager
def mock_mode():
    """Context manager that enables mock mode in the current thread.

    While active, TabPFN fit/predict entry points are replaced by mocks that
    only accumulate simulated time and cost, all warnings/logging are
    silenced, and ``time.time`` is patched to the simulated clock.

    Yields:
        A zero-argument callable returning
        ``(elapsed_mock_seconds, accumulated_mock_cost)``.
    """
    set_mock_cost(0.0)
    start_time = time.time()
    set_mock_time(start_time)

    # Store original logging levels for all loggers so they can be restored.
    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
    loggers.append(logging.getLogger())
    original_levels = {logger: logger.level for logger in loggers}

    # Overwrite actual fit and predict functions with mock functions,
    # remembering the originals for restoration in ``finally``.
    if get_is_local_tabpfn():
        from tabpfn import TabPFNClassifier, TabPFNRegressor

        originals = {
            "fit_clf": TabPFNClassifier.fit,
            "predict_clf": TabPFNClassifier.predict,
            "predict_proba_clf": TabPFNClassifier.predict_proba,
            "fit_reg": TabPFNRegressor.fit,
            "predict_reg": TabPFNRegressor.predict,
        }
        TabPFNClassifier.fit = mock_fit_local
        TabPFNClassifier.predict = mock_predict_local
        TabPFNClassifier.predict_proba = mock_predict_proba_local
        TabPFNRegressor.fit = mock_fit_local
        TabPFNRegressor.predict = mock_predict_proba_local
    else:
        from tabpfn_client.service_wrapper import InferenceClient

        originals = {
            "fit": InferenceClient.fit,
            "predict": InferenceClient.predict,
        }
        InferenceClient.fit = classmethod(mock_fit_client)
        InferenceClient.predict = classmethod(mock_predict)

    try:
        # Suppress all warnings and logging while the simulation runs.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for logger in loggers:
                logger.setLevel(logging.ERROR)

            with mock.patch("time.time", side_effect=get_mock_time):
                yield lambda: (get_mock_time() - start_time, get_mock_cost())
    finally:
        # Restore the real implementations (classes captured above — no
        # re-import needed).
        if get_is_local_tabpfn():
            TabPFNClassifier.fit = originals["fit_clf"]
            TabPFNClassifier.predict = originals["predict_clf"]
            TabPFNClassifier.predict_proba = originals["predict_proba_clf"]
            TabPFNRegressor.fit = originals["fit_reg"]
            TabPFNRegressor.predict = originals["predict_reg"]
        else:
            InferenceClient.fit = originals["fit"]
            InferenceClient.predict = originals["predict"]

        # BUGFIX: restore logger levels even when the body raises — the
        # original code ran this after the try/finally, so an exception
        # propagating through ``yield`` left every logger muted at ERROR.
        for logger, level in original_levels.items():
            logger.setLevel(level)


def simulate_first(func):
    """Decorator that first runs the decorated function in mock mode to simulate
    its duration and credit usage. If client is used, only executes function if
    enough credits are available.

    NOTE: the decorated function is executed twice — once with TabPFN
    fit/predict mocked (cheap simulation) and once for real — so it should be
    free of non-TabPFN side effects.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        set_is_local_tabpfn()

        # Dry run: TabPFN calls only accumulate simulated time/cost.
        with mock_mode() as get_simulation_results:
            func(*args, **kwargs)
        time_estimate, credit_estimate = get_simulation_results()

        if not get_is_local_tabpfn():
            from tabpfn_client import get_access_token
            from tabpfn_client.client import ServiceClient

            access_token = get_access_token()
            api_usage = ServiceClient.get_api_usage(access_token)

            # A usage_limit of -1 means "unlimited"; otherwise fail fast when
            # the remaining credits would not cover the estimated usage.
            if (
                api_usage["usage_limit"] != -1
                and api_usage["usage_limit"] - api_usage["current_usage"]
                < credit_estimate
            ):
                raise RuntimeError(
                    f"Not enough credits left. Estimated credit usage: {credit_estimate}, credits left: {api_usage['usage_limit'] - api_usage['current_usage']}",
                )
            print("Enough credits left.")  # noqa: T201

        print(  # noqa: T201
            f"Estimated duration: {time_estimate:.1f} seconds {'(on GPU)' if get_is_local_tabpfn() else ''}",
        )

        # Real execution.
        return func(*args, **kwargs)

    return wrapper
Loading