[WIP] Batch size predict #792
Changes from all commits: `64b60cf`, `44c94f6`, `adcf1e8`, `0fea4c2`, `04c6a79`, `ab68558`, `bda6532`
```diff
@@ -0,0 +1 @@
+Added automatic batching option in predict() functions.
```
```diff
@@ -40,6 +40,7 @@
     get_embeddings,
     initialize_model_variables_helper,
     initialize_telemetry,
+    set_multiquery_item_attention,
 )
 from tabpfn.constants import (
     PROBABILITY_EPSILON_ROUND_ZERO,
```
```diff
@@ -379,8 +380,9 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
                 False and True.

             !!! warning
-                This does not batch the original input data. We still recommend to
-                batch the test set as necessary if you run out of memory.
+                This does not batch the original input data. If you run out of
+                memory during prediction, use `batch_size_predict` in the predict
+                method to automatically batch the test set.

         random_state:
             Controls the randomness of the model. Pass an int for reproducible
```
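For context, a minimal usage sketch of the new option (a sketch against this PR's branch; the data setup and batch size are illustrative, not part of the diff):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier  # assumes this PR's branch is installed

X, y = make_classification(n_samples=2_000, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# Test samples are processed in chunks of 200 to bound peak memory.
y_pred = clf.predict(X_test, batch_size_predict=200)
assert y_pred.shape == (len(X_test),)
```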
```diff
@@ -1018,6 +1020,8 @@ def _raw_predict(
         *,
         return_logits: bool,
         return_raw_logits: bool = False,
+        batch_size_predict: int | None = None,
+        batch_predict_enable_test_interaction: bool = False,
     ) -> torch.Tensor:
         """Internal method to run prediction.
```
```diff
@@ -1032,6 +1036,13 @@ def _raw_predict(
                 post-processing steps.
             return_raw_logits: If True, returns the raw logits without
                 averaging estimators or temperature scaling.
+            batch_size_predict: If set, predictions are batched into chunks
+                of this size. If None, no batching is performed.
+            batch_predict_enable_test_interaction: If False (default), test
+                samples only attend to training samples during batched prediction,
+                ensuring predictions match unbatched. If True, test samples can
+                attend to each other within a batch, so predictions may vary
+                depending on batch size.

         Returns:
             The raw torch.Tensor output, either logits or probabilities,
```
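With the default `batch_predict_enable_test_interaction=False`, the docstring promises that batched and unbatched outputs coincide; a quick check one could run (continuing the sketch above, so `clf` and `X_test` are the assumed fixtures):

```python
import numpy as np

proba_unbatched = clf.predict_proba(X_test)
proba_batched = clf.predict_proba(X_test, batch_size_predict=128)

# Disabling test-test attention makes the batched path equivalent to the
# unbatched one, up to floating-point noise.
assert np.allclose(proba_unbatched, proba_batched, atol=1e-5)
```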
```diff
@@ -1052,6 +1063,16 @@ def _raw_predict(
             ord_encoder=getattr(self, "ordinal_encoder_", None),
         )

+        # If batch_size_predict is set, batch the predictions
+        if batch_size_predict is not None:
+            return self._batched_raw_predict(
+                X,
+                batch_size_predict=batch_size_predict,
+                batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
+                return_logits=return_logits,
+                return_raw_logits=return_raw_logits,
+            )
+
         with handle_oom_errors(self.devices_, X, model_type="classifier"):
             return self.forward(
                 X,
```
```diff
@@ -1060,17 +1081,90 @@ def _raw_predict(
                 return_raw_logits=return_raw_logits,
             )

+    def _batched_raw_predict(
+        self,
+        X: XType,
+        *,
+        batch_size_predict: int,
+        batch_predict_enable_test_interaction: bool,
+        return_logits: bool,
+        return_raw_logits: bool = False,
+    ) -> torch.Tensor:
+        """Run batched prediction to avoid OOM on large test sets.
+
+        Args:
+            X: The input data for prediction.
+            batch_size_predict: The batch size for predictions.
+            batch_predict_enable_test_interaction: If False, test samples only
+                attend to training samples, ensuring predictions match unbatched.
+                If True, predictions may vary depending on batch size.
+            return_logits: If True, the logits are returned.
+            return_raw_logits: If True, returns the raw logits without
+                averaging estimators or temperature scaling.
+
+        Returns:
+            The concatenated predictions from all batches.
+        """
+        # Disable multiquery attention for consistent predictions (matching
+        # unbatched) unless batch_predict_enable_test_interaction is True
+        if not batch_predict_enable_test_interaction:
+            set_multiquery_item_attention(self, enabled=False)
+
+        try:
+            results = []
+            n_samples = X.shape[0] if hasattr(X, "shape") else len(X)
+
+            for start in range(0, n_samples, batch_size_predict):
+                end = min(start + batch_size_predict, n_samples)
+                X_batch = X[start:end]
+
+                with handle_oom_errors(self.devices_, X_batch, model_type="classifier"):
+                    batch_result = self.forward(
+                        X_batch,
+                        use_inference_mode=True,
+                        return_logits=return_logits,
+                        return_raw_logits=return_raw_logits,
+                    )
+                results.append(batch_result)
+
+            # Concatenate along the appropriate dimension:
+            # raw logits: (n_estimators, n_samples, n_classes) -> dim 1
+            # logits/probas: (n_samples, n_classes) -> dim 0
+            concat_dim = 1 if return_raw_logits else 0
+            return torch.cat(results, dim=concat_dim)
+        finally:
+            # Restore multiquery attention if we disabled it
+            if not batch_predict_enable_test_interaction:
+                set_multiquery_item_attention(self, enabled=True)
+
     @track_model_call(model_method="predict", param_names=["X"])
-    def predict(self, X: XType) -> np.ndarray:
+    def predict(
+        self,
+        X: XType,
+        *,
+        batch_size_predict: int | None = None,
+        batch_predict_enable_test_interaction: bool = False,
+    ) -> np.ndarray:
         """Predict the class labels for the provided input samples.

         Args:
             X: The input data for prediction.
+            batch_size_predict: If set, predictions are batched into chunks
+                of this size to avoid OOM errors. If None, no batching is performed.
+            batch_predict_enable_test_interaction: If False (default), test
+                samples only attend to training samples during batched prediction,
+                ensuring predictions match unbatched. If True, test samples can
+                attend to each other within a batch, so predictions may vary
+                depending on batch size.

         Returns:
             The predicted class labels as a NumPy array.
         """
-        probas = self._predict_proba(X=X)
+        probas = self._predict_proba(
+            X=X,
+            batch_size_predict=batch_size_predict,
+            batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
+        )
         y_pred = np.argmax(probas, axis=1)
         if hasattr(self, "label_encoder_") and self.label_encoder_ is not None:
             return self.label_encoder_.inverse_transform(y_pred)
```
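The concatenation rule at the end of `_batched_raw_predict` can be illustrated in isolation (the tensor shapes follow the inline comments in the diff):

```python
import torch

# Raw logits keep a leading estimator axis: (n_estimators, n_samples, n_classes).
raw_a, raw_b = torch.zeros(4, 100, 3), torch.zeros(4, 28, 3)
assert torch.cat([raw_a, raw_b], dim=1).shape == (4, 128, 3)  # samples on dim 1

# Averaged logits/probabilities drop that axis: (n_samples, n_classes).
out_a, out_b = torch.zeros(100, 3), torch.zeros(28, 3)
assert torch.cat([out_a, out_b], dim=0).shape == (128, 3)  # samples on dim 0
```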
```diff
@@ -1079,24 +1173,48 @@ def predict(self, X: XType) -> np.ndarray:

     @config_context(transform_output="default")
     @track_model_call(model_method="predict", param_names=["X"])
-    def predict_logits(self, X: XType) -> np.ndarray:
+    def predict_logits(
+        self,
+        X: XType,
+        *,
+        batch_size_predict: int | None = None,
+        batch_predict_enable_test_interaction: bool = False,
+    ) -> np.ndarray:
         """Predict the raw logits for the provided input samples.

         Logits represent the unnormalized log-probabilities of the classes
         before the softmax activation function is applied.

         Args:
             X: The input data for prediction.
+            batch_size_predict: If set, predictions are batched into chunks
+                of this size to avoid OOM errors. If None, no batching is performed.
+            batch_predict_enable_test_interaction: If False (default), test
+                samples only attend to training samples during batched prediction,
+                ensuring predictions match unbatched. If True, test samples can
+                attend to each other within a batch, so predictions may vary
+                depending on batch size.
```
Comment on lines +1190 to +1196 (Contributor): The docstrings for […]
```diff
         Returns:
             The predicted logits as a NumPy array. Shape (n_samples, n_classes).
         """
-        logits_tensor = self._raw_predict(X, return_logits=True)
+        logits_tensor = self._raw_predict(
+            X,
+            return_logits=True,
+            batch_size_predict=batch_size_predict,
+            batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
+        )
         return logits_tensor.float().detach().cpu().numpy()

     @config_context(transform_output="default")
     @track_model_call(model_method="predict", param_names=["X"])
-    def predict_raw_logits(self, X: XType) -> np.ndarray:
+    def predict_raw_logits(
+        self,
+        X: XType,
+        *,
+        batch_size_predict: int | None = None,
+        batch_predict_enable_test_interaction: bool = False,
+    ) -> np.ndarray:
         """Predict the raw logits for the provided input samples.

         Logits represent the unnormalized log-probabilities of the classes
```
```diff
@@ -1106,6 +1224,13 @@ def predict_raw_logits(self, X: XType) -> np.ndarray:

         Args:
             X: The input data for prediction.
+            batch_size_predict: If set, predictions are batched into chunks
+                of this size to avoid OOM errors. If None, no batching is performed.
+            batch_predict_enable_test_interaction: If False (default), test
+                samples only attend to training samples during batched prediction,
+                ensuring predictions match unbatched. If True, test samples can
+                attend to each other within a batch, so predictions may vary
+                depending on batch size.

         Returns:
             An array of predicted logits for each estimator,
```
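A short sketch of the batched raw-logits path (same assumed fixtures as the earlier example; the shape interpretation follows the diff's comments):

```python
raw = clf.predict_raw_logits(X_test, batch_size_predict=256)

# One logit matrix per ensemble member: (n_estimators, n_samples, n_classes).
n_estimators, n_samples, n_classes = raw.shape
assert n_samples == len(X_test)
```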
```diff
@@ -1115,37 +1240,78 @@ def predict_raw_logits(self, X: XType) -> np.ndarray:
             X,
             return_logits=False,
             return_raw_logits=True,
+            batch_size_predict=batch_size_predict,
+            batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
         )
         return logits_tensor.float().detach().cpu().numpy()

     @track_model_call(model_method="predict", param_names=["X"])
-    def predict_proba(self, X: XType) -> np.ndarray:
+    def predict_proba(
+        self,
+        X: XType,
+        *,
+        batch_size_predict: int | None = None,
+        batch_predict_enable_test_interaction: bool = False,
+    ) -> np.ndarray:
         """Predict the probabilities of the classes for the provided input samples.

         This is a wrapper around the `_predict_proba` method.

         Args:
             X: The input data for prediction.
+            batch_size_predict: If set, predictions are batched into chunks
+                of this size to avoid OOM errors. If None, no batching is performed.
+            batch_predict_enable_test_interaction: If False (default), test
+                samples only attend to training samples during batched prediction,
+                ensuring predictions match unbatched. If True, test samples can
+                attend to each other within a batch, so predictions may vary
+                depending on batch size.

         Returns:
             The predicted probabilities of the classes as a NumPy array.
             Shape (n_samples, n_classes).
         """
-        return self._predict_proba(X)
+        return self._predict_proba(
+            X,
+            batch_size_predict=batch_size_predict,
+            batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
+        )

     @config_context(transform_output="default")  # type: ignore
-    def _predict_proba(self, X: XType) -> np.ndarray:
+    def _predict_proba(
+        self,
+        X: XType,
+        *,
+        batch_size_predict: int | None = None,
+        batch_predict_enable_test_interaction: bool = False,
+    ) -> np.ndarray:
         """Predict the probabilities of the classes for the provided input samples.

         Args:
             X: The input data for prediction.
+            batch_size_predict: If set, predictions are batched into chunks
+                of this size. If None, no batching is performed.
+            batch_predict_enable_test_interaction: If False (default), test
+                samples only attend to training samples during batched prediction,
+                ensuring predictions match unbatched. If True, test samples can
+                attend to each other within a batch, so predictions may vary
+                depending on batch size.

         Returns:
             The predicted probabilities of the classes as a NumPy array.
             Shape (n_samples, n_classes).
         """
         probas = (
-            self._raw_predict(X, return_logits=False).float().detach().cpu().numpy()
+            self._raw_predict(
+                X,
+                return_logits=False,
+                batch_size_predict=batch_size_predict,
+                batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
+            )
+            .float()
+            .detach()
+            .cpu()
+            .numpy()
         )
         probas = self._maybe_reweight_probas(probas=probas)
         if self.inference_config_.USE_SKLEARN_16_DECIMAL_PRECISION:
```
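And the probability path, for completeness (same assumed fixtures); rows should still sum to one after the post-processing steps above:

```python
import numpy as np

proba = clf.predict_proba(X_test, batch_size_predict=256)
assert proba.shape[0] == len(X_test)
assert np.allclose(proba.sum(axis=1), 1.0, atol=1e-4)
```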
```diff
@@ -63,19 +63,13 @@ def __init__(
         n_test_samples: int | None = None,
         model_type: str = "classifier",
     ):
-        predict_method = "predict_proba" if model_type == "classifier" else "predict"
-
         size_info = f" with {n_test_samples:,} test samples" if n_test_samples else ""

+        model_class = f"TabPFN{model_type.title()}"
         message = (
             f"{self.device_name} out of memory{size_info}.\n\n"
-            f"Solution: Split your test data into smaller batches:\n\n"
-            f"    batch_size = 1000  # depends on hardware\n"
-            f"    predictions = []\n"
-            f"    for i in range(0, len(X_test), batch_size):\n"
-            f"        batch = model.{predict_method}(X_test[i:i + batch_size])\n"
-            f"        predictions.append(batch)\n"
-            f"    predictions = np.vstack(predictions)"
+            f"Solution: Set batch_size_predict when creating the model:\n\n"
+            f"    model = {model_class}(batch_size_predict=1000)"
         )
```
Comment on lines +68 to 73 (Contributor): The suggested solution in the OOM error message is incorrect: in this PR, `batch_size_predict` is an argument of the `predict()` and `predict_proba()` methods, not of the model constructor. For example:

```python
message = (
    f"{self.device_name} out of memory{size_info}.\n\n"
    f"Solution: Set `batch_size_predict` in the `predict()` or "
    f"`predict_proba()` method, for example:\n\n"
    f"    predictions = model.predict(X_test, batch_size_predict=1000)"
)
```
```diff
         if original_error is not None:
             message += f"\n\nOriginal error: {original_error}"
```
Review comment: The function `set_multiquery_item_attention` assumes that `model.executor_` always has a `model_caches` attribute. However, `SingleDeviceInferenceEngine` and its subclasses (such as `InferenceEngineBatchedNoPreprocessing`, used in `fit_mode="batched"`, and `InferenceEngineCacheKV`, used in `fit_mode="fit_with_cache"`) do not have this attribute; they use a `models` attribute instead.

When `predict()` is called with `batch_size_predict` while the model is in one of these modes, an `AttributeError` will be raised, causing a crash (Denial of Service).

To fix this, check for the existence of `model_caches` or handle both engine types explicitly.
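A sketch of the fix this comment suggests, under two explicit assumptions that are not verified against the codebase: that both engine families expose their underlying modules directly via `model_caches` or `models`, and that the toggled flag is named `multiquery_item_attention_for_test_set`:

```python
def _iter_engine_modules(executor):
    """Yield the model modules held by an inference engine.

    Assumption: `model_caches` / `models` are iterables of modules;
    the real container layout may differ.
    """
    if hasattr(executor, "model_caches"):
        yield from executor.model_caches
    elif hasattr(executor, "models"):
        # SingleDeviceInferenceEngine subclasses store modules here instead.
        yield from executor.models
    else:
        raise TypeError(f"Unsupported inference engine: {type(executor).__name__}")


def set_multiquery_item_attention(model, *, enabled: bool) -> None:
    # Check both attributes instead of assuming `model_caches` exists,
    # avoiding the AttributeError described above.
    for module in _iter_engine_modules(model.executor_):
        module.multiquery_item_attention_for_test_set = enabled  # assumed flag name
```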