-
Notifications
You must be signed in to change notification settings - Fork 56
Add extension simulation #75
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
e46db79
Add simulation setup to estimate credit usage for PHE with client
davidotte dd1df47
Add support for local TabPFN
davidotte a059170
Add submodule tabpfn_common_utils
davidotte 941b7c1
Fix local TabPFN simulation
davidotte 5fc38c1
Remove tabpfn_common_utils submodule
davidotte 16b0b6e
Add small improvements
davidotte c8d90c2
Fix ruff errors
davidotte 576ae19
Fix further ruff errors
davidotte 57ce2a3
Fix utils imports
davidotte 096fe26
Fix small errors
davidotte 86e71bb
Fix noqa directive
davidotte b382225
Update tabpfn_common_utils reference
davidotte File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
# Create alias for test_utils
from . import test_utils
from .simulator import simulate_first
from .utils import (
    ClientTabPFNClassifier,
    ClientTabPFNRegressor,
    LocalTabPFNClassifier,
    LocalTabPFNRegressor,
    TabPFNClassifier,
    TabPFNRegressor,
    get_device,
    get_tabpfn_models,
    infer_categorical_features,
    is_tabpfn,
    product_dict,
    softmax,
)

# Public API of this package; keep in sync with the imports above.
__all__ = [
    "get_tabpfn_models",
    "is_tabpfn",
    "test_utils",
    "simulate_first",
    "get_device",
    "TabPFNClassifier",
    "TabPFNRegressor",
    "LocalTabPFNClassifier",
    "LocalTabPFNRegressor",
    "ClientTabPFNClassifier",
    "ClientTabPFNRegressor",
    "infer_categorical_features",
    "softmax",
    "product_dict",
]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,278 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import contextlib | ||
| import functools | ||
| import logging | ||
| import os | ||
| import threading | ||
| import time | ||
| import warnings | ||
| from typing import Literal | ||
| from unittest import mock | ||
|
|
||
| import numpy as np | ||
| from tabpfn_common_utils.expense_estimation import estimate_duration | ||
|
|
||
# Added to client-side duration estimates so costs are slightly overestimated
# (safer than underestimating); value in seconds.
CLIENT_COST_ESTIMATION_LATENCY_OFFSET = 1.0


# Use thread-local variables to keep track of the current mode, simulated costs and time.
_thread_local = threading.local()
|
|
||
|
|
||
# Block of small helper functions to access and modify the thread-local
# variables used for mock prediction in a simple and unified way.
def get_is_local_tabpfn() -> bool:
    """Return whether this thread is configured to use a local TabPFN.

    Defaults to True when set_is_local_tabpfn() has not yet run in this thread.
    """
    try:
        return _thread_local.use_local_tabpfn
    except AttributeError:
        return True
|
|
||
|
|
||
def set_is_local_tabpfn():
    """Figure out whether local TabPFN or client is used and set thread-local variable."""
    # The environment variable opts in/out; the package must also be importable.
    env_says_local = os.getenv("USE_TABPFN_LOCAL", "true").lower() == "true"

    try:
        from tabpfn import TabPFNClassifier  # noqa: F401
    except ImportError:
        local_available = False
    else:
        local_available = True

    _thread_local.use_local_tabpfn = env_says_local and local_available
|
|
||
|
|
||
def get_mock_cost() -> float:
    """Return the credit cost accumulated in this thread (0.0 if never set)."""
    try:
        return _thread_local.cost
    except AttributeError:
        return 0.0
|
|
||
|
|
||
def increment_mock_cost(value: float):
    """Add *value* to the thread-local simulated credit cost."""
    current = get_mock_cost()
    _thread_local.cost = current + value
|
|
||
|
|
||
def set_mock_cost(value: float = 0.0):
    """Reset the thread-local simulated credit cost to *value*."""
    setattr(_thread_local, "cost", value)
|
|
||
|
|
||
def get_mock_time() -> float:
    """Return the thread-local simulated clock.

    Unlike get_mock_cost there is deliberately no default: reading the clock
    before mock_mode() has set it raises AttributeError, surfacing the misuse.
    """
    return getattr(_thread_local, "mock_time")
|
|
||
|
|
||
def set_mock_time(value: float):
    """Set the thread-local simulated clock to *value* (epoch seconds)."""
    setattr(_thread_local, "mock_time", value)
|
|
||
|
|
||
def increment_mock_time(seconds: float):
    """Advance the thread-local simulated clock by *seconds*."""
    advanced = get_mock_time() + seconds
    set_mock_time(advanced)
|
|
||
|
|
||
# Block of functions that will replace the actual fit and predict functions in mock mode.
def mock_fit_local(self, X, y, config=None):
    """Replacement for local TabPFN fit: record the training data, skip real fitting.

    The training set is stashed on the estimator because the mocked predict
    path needs it to estimate cost and duration; the client already stores
    this internally, local TabPFN does not.
    """
    self.X_train, self.y_train = X, y
    return "mock_id"
|
|
||
|
|
||
def mock_fit_client(cls, X, y, config=None):
    """Replacement for the client's fit: a no-op returning a fake train-set id."""
    return "mock_id"
|
|
||
|
|
||
def mock_predict_local(self, X_test):
    """Route classifier .predict through the proba wrapper, flagged as a plain
    predict call so the mocked output type becomes "preds" rather than "probas".
    """
    return mock_predict_proba_local(
        self,
        X_test,
        from_classifier_predict=True,
    )
|
|
||
|
|
||
def mock_predict_proba_local(self, X_test, from_classifier_predict=False):
    """Adapt a local-estimator call to mock_predict's argument layout.

    The client already performs this translation internally; for local TabPFN
    the task is derived from the estimator's class name and the matching
    output type is selected.
    """
    is_classifier = self.__class__.__name__ == "TabPFNClassifier"
    task = "classification" if is_classifier else "regression"

    if is_classifier:
        output_type = "preds" if from_classifier_predict else "probas"
    else:
        output_type = "mean"

    return mock_predict(
        self,
        X_test,
        task,
        "dummy",
        self.X_train,
        self.y_train,
        {"n_estimators": self.n_estimators},
        {"output_type": output_type},
    )
|
|
||
|
|
||
def mock_predict(
    cls,
    X_test,
    task: Literal["classification", "regression"],
    train_set_uid: str,
    X_train,
    y_train,
    config=None,
    predict_params=None,
):
    """Mock function for prediction, which can be called instead of the real
    prediction function. Outputs random results in the expected format and
    keeps track of the simulated cost and time.

    Raises:
        ValueError: if X_train or y_train is None (fit was not mocked first).
    """
    if X_train is None or y_train is None:
        raise ValueError(
            "X_train and y_train must be provided in mock mode during prediction.",
        )

    # Guard the documented defaults: with config=None or predict_params=None
    # the original crashed below on `config.get(...)` (AttributeError) and
    # `predict_params["output_type"]` (TypeError/KeyError).
    config = config or {}
    predict_params = predict_params or {}

    duration = estimate_duration(
        num_rows=X_train.shape[0] + X_test.shape[0],
        num_features=X_test.shape[1],
        task=task,
        tabpfn_config=config,
        latency_offset=0
        if get_is_local_tabpfn()
        else CLIENT_COST_ESTIMATION_LATENCY_OFFSET,  # To slightly overestimate (safer)
    )
    increment_mock_time(duration)

    # Simulated credit cost: number of processed cells times ensemble size.
    cost = (
        (X_train.shape[0] + X_test.shape[0])
        * X_test.shape[1]
        * config.get("n_estimators", 4 if task == "classification" else 8)
    )
    increment_mock_cost(cost)

    # Return random result in the correct format; missing output_type falls
    # back to the task's default ("preds" / "mean").
    output_type = predict_params.get("output_type")
    if task == "classification":
        if not output_type or output_type == "preds":
            return np.random.rand(X_test.shape[0])
        if output_type == "probas":
            probs = np.random.rand(X_test.shape[0], len(np.unique(y_train)))
            return probs / probs.sum(axis=1, keepdims=True)
    elif task == "regression":
        if not output_type or output_type == "mean":
            return np.random.rand(X_test.shape[0])
        if output_type == "full":
            return {
                "logits": np.random.rand(X_test.shape[0], 5000),
                "mean": np.random.rand(X_test.shape[0]),
                "median": np.random.rand(X_test.shape[0]),
                "mode": np.random.rand(X_test.shape[0]),
                "quantiles": np.random.rand(3, X_test.shape[0]),
                "borders": np.random.rand(5001),
                "ei": np.random.rand(X_test.shape[0]),
                "pi": np.random.rand(X_test.shape[0]),
            }
    # Unrecognized output_type: mirror the original behavior of returning None.
    return None
|
|
||
|
|
||
@contextlib.contextmanager
def mock_mode():
    """Context manager that enables mock mode in the current thread.

    Yields a zero-argument callable returning the tuple
    (simulated elapsed seconds, simulated credit cost) accumulated so far.
    While active, TabPFN fit/predict entry points are monkey-patched with the
    mock_* functions above, warnings are suppressed, all loggers are raised to
    ERROR, and time.time is patched to the simulated clock. Everything is
    restored on exit.

    NOTE(review): the patching is process-global (class attributes), while the
    cost/time bookkeeping is thread-local — concurrent threads entering and
    leaving mock_mode may interfere with each other's patches.
    """
    # Reset simulated cost and anchor the simulated clock to real time.
    set_mock_cost(0.0)
    start_time = time.time()
    set_mock_time(start_time)

    # Store original logging levels for all loggers
    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
    loggers.append(logging.getLogger())
    original_levels = {logger: logger.level for logger in loggers}

    # Overwrite actual fit and predict functions with mock functions
    if get_is_local_tabpfn():
        from tabpfn import TabPFNClassifier, TabPFNRegressor

        original_fit_classification = TabPFNClassifier.fit
        original_fit_regressor = TabPFNRegressor.fit
        original_predict_classification = TabPFNClassifier.predict
        original_predict_proba_classification = TabPFNClassifier.predict_proba
        original_predict_regressor = TabPFNRegressor.predict
        TabPFNClassifier.fit = mock_fit_local
        TabPFNClassifier.predict = mock_predict_local
        TabPFNClassifier.predict_proba = mock_predict_proba_local
        TabPFNRegressor.fit = mock_fit_local
        # Regressors have no predict/predict_proba split, so predict is routed
        # through mock_predict_proba_local directly.
        TabPFNRegressor.predict = mock_predict_proba_local
    else:
        from tabpfn_client.service_wrapper import InferenceClient

        original_fit = InferenceClient.fit
        original_predict = InferenceClient.predict
        InferenceClient.fit = classmethod(mock_fit_client)
        InferenceClient.predict = classmethod(mock_predict)

    # Suppress all warnings and logging
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Set all loggers to ERROR level
        for logger in loggers:
            logger.setLevel(logging.ERROR)

        # Patch time.time so code that times itself observes the simulated
        # clock advanced by increment_mock_time.
        with mock.patch("time.time", side_effect=get_mock_time):
            try:
                yield lambda: (get_mock_time() - start_time, get_mock_cost())
            finally:
                # Restore the original implementations even if the body raised.
                if get_is_local_tabpfn():
                    from tabpfn.classifier import TabPFNClassifier
                    from tabpfn.regressor import TabPFNRegressor

                    TabPFNClassifier.fit = original_fit_classification
                    TabPFNClassifier.predict = original_predict_classification
                    TabPFNClassifier.predict_proba = (
                        original_predict_proba_classification
                    )
                    TabPFNRegressor.fit = original_fit_regressor
                    TabPFNRegressor.predict = original_predict_regressor
                else:
                    from tabpfn_client.service_wrapper import InferenceClient

                    InferenceClient.fit = original_fit
                    InferenceClient.predict = original_predict

                # Restore original logging levels
                for logger in loggers:
                    logger.setLevel(original_levels[logger])
|
|
||
|
|
||
def simulate_first(func):
    """Decorator that first runs the decorated function in mock mode to simulate
    its duration and credit usage. If client is used, only executes the function
    if enough credits are available.

    Note: the wrapped function is executed twice — once under mock_mode()
    (where fit/predict are monkey-patched, so no real inference runs) and once
    for real. Side effects outside TabPFN calls therefore happen twice; keep
    decorated functions idempotent apart from their TabPFN usage.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Decide (per thread) whether local TabPFN or the client backend is used.
        set_is_local_tabpfn()

        # Dry run: the patched fit/predict only accumulate simulated time/cost.
        with mock_mode() as get_simulation_results:
            func(*args, **kwargs)
            time_estimate, credit_estimate = get_simulation_results()

        if not get_is_local_tabpfn():
            from tabpfn_client import get_access_token
            from tabpfn_client.client import ServiceClient

            access_token = get_access_token()
            api_usage = ServiceClient.get_api_usage(access_token)

            # usage_limit == -1 means unlimited credits.
            credits_left = api_usage["usage_limit"] - api_usage["current_usage"]
            if api_usage["usage_limit"] != -1 and credits_left < credit_estimate:
                raise RuntimeError(
                    f"Not enough credits left. Estimated credit usage: {credit_estimate}, credits left: {credits_left}",
                )
            print("Enough credits left.")  # noqa: T201

        print(  # noqa: T201
            f"Estimated duration: {time_estimate:.1f} seconds {'(on GPU)' if get_is_local_tabpfn() else ''}",
        )

        # Real run: return the actual result.
        return func(*args, **kwargs)

    return wrapper
File renamed without changes.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The simulate_first decorator executes the decorated function twice (once in simulation mode and again for the actual call). This could lead to unintended side effects if the function is not idempotent; consider either ensuring that the function is side-effect free during simulation or refactoring to avoid duplicate execution.