Conversation
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Code Review
This pull request introduces an automatic batching option for prediction methods to prevent out-of-memory errors on large test sets. However, a critical Denial of Service vulnerability was identified in the set_multiquery_item_attention utility function. This function fails to account for all types of inference engines, specifically those that do not use a model cache, which will lead to an AttributeError when batching is used with certain model configurations (e.g., fit_mode='batched'). Additionally, the review includes a critical fix for an incorrect error message and a medium-severity suggestion to address duplicated docstrings for the new parameters to improve long-term maintainability.
model_class = f"TabPFN{model_type.title()}"
message = (
    f"{self.device_name} out of memory{size_info}.\n\n"
    f"Solution: Split your test data into smaller batches:\n\n"
    f" batch_size = 1000 # depends on hardware\n"
    f" predictions = []\n"
    f" for i in range(0, len(X_test), batch_size):\n"
    f" batch = model.{predict_method}(X_test[i:i + batch_size])\n"
    f" predictions.append(batch)\n"
    f" predictions = np.vstack(predictions)"
    f"Solution: Set batch_size_predict when creating the model:\n\n"
    f" model = {model_class}(batch_size_predict=1000)"
)
The suggested solution in the OOM error message is incorrect. batch_size_predict is a parameter for the predict() methods, not the TabPFN* constructor. This will confuse users trying to resolve out-of-memory issues. The message should guide them to use the parameter in the prediction call.
message = (
    f"{self.device_name} out of memory{size_info}.\n\n"
    f"Solution: Set `batch_size_predict` in the `predict()` or `predict_proba()` method, for example:\n\n"
    f" predictions = model.predict(X_test, batch_size_predict=1000)"
)
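For completeness, a minimal usage sketch of the corrected advice (the dataset setup is illustrative, and batch_size_predict on predict()/predict_proba() is the parameter this PR introduces, not an existing TabPFN argument):

# Sketch of the corrected guidance; `batch_size_predict` is the keyword
# argument this PR adds to predict()/predict_proba(), not a constructor parameter.
from sklearn.datasets import make_classification
from tabpfn import TabPFNClassifier

X, y = make_classification(n_samples=5000, n_features=10, random_state=0)
X_train, y_train, X_test = X[:1000], y[:1000], X[1000:]

model = TabPFNClassifier()
model.fit(X_train, y_train)

# Predictions are computed in chunks of 1000 test rows to avoid OOM.
predictions = model.predict(X_test, batch_size_predict=1000)
probabilities = model.predict_proba(X_test, batch_size_predict=1000)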
def set_multiquery_item_attention(
    model: TabPFNClassifier | TabPFNRegressor,
    *,
    enabled: bool,
) -> None:
    """Set multiquery_item_attention_for_test_set on all model layers.

    This controls whether test samples attend to each other during inference.
    Disabling it ensures predictions are consistent across different batch sizes.

    Args:
        model: The fitted TabPFN model.
        enabled: If True, test samples can attend to each other.
            If False, test samples only attend to training samples.
    """
    for model_cache in model.executor_.model_caches:
        for m in model_cache._models.values():
            for module in m.modules():
                if hasattr(module, "multiquery_item_attention_for_test_set"):
                    module.multiquery_item_attention_for_test_set = enabled
The function set_multiquery_item_attention assumes that model.executor_ always has a model_caches attribute. However, SingleDeviceInferenceEngine and its subclasses (such as InferenceEngineBatchedNoPreprocessing used in fit_mode="batched" and InferenceEngineCacheKV used in fit_mode="fit_with_cache") do not have this attribute; they use a models attribute instead.
When predict() is called with batch_size_predict while the model is in one of these modes, an AttributeError will be raised, causing a crash (Denial of Service).
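For context, the failing call pattern would look roughly like this (a sketch; the fit modes and engine names come from the description above, and batch_size_predict is the parameter this PR adds):

# Hypothetical repro: fit_mode="fit_with_cache" selects InferenceEngineCacheKV,
# which exposes `.models` rather than `.model_caches`.
from sklearn.datasets import make_classification
from tabpfn import TabPFNClassifier

X, y = make_classification(n_samples=600, n_features=8, random_state=0)

model = TabPFNClassifier(fit_mode="fit_with_cache")
model.fit(X[:200], y[:200])

# Before the fix, this crashes: set_multiquery_item_attention iterates
# model.executor_.model_caches, which this engine does not have.
model.predict(X[200:], batch_size_predict=128)  # AttributeError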
To fix this, check for the existence of model_caches or handle both engine types explicitly.
def set_multiquery_item_attention(
    model: TabPFNClassifier | TabPFNRegressor,
    *,
    enabled: bool,
) -> None:
    """Set multiquery_item_attention_for_test_set on all model layers.

    This controls whether test samples attend to each other during inference.
    Disabling it ensures predictions are consistent across different batch sizes.

    Args:
        model: The fitted TabPFN model.
        enabled: If True, test samples can attend to each other.
            If False, test samples only attend to training samples.
    """
    executor = model.executor_
    if hasattr(executor, "model_caches"):
        for model_cache in executor.model_caches:
            for m in model_cache._models.values():
                for module in m.modules():
                    if hasattr(module, "multiquery_item_attention_for_test_set"):
                        module.multiquery_item_attention_for_test_set = enabled
    elif hasattr(executor, "models"):
        for m in executor.models:
            for module in m.modules():
                if hasattr(module, "multiquery_item_attention_for_test_set"):
                    module.multiquery_item_attention_for_test_set = enabled
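A slightly more compact variant of the same fix (a sketch relying only on the executor attributes described above) avoids duplicating the inner loops:

def set_multiquery_item_attention(
    model: TabPFNClassifier | TabPFNRegressor,
    *,
    enabled: bool,
) -> None:
    """Set multiquery_item_attention_for_test_set on all model layers."""
    executor = model.executor_
    if hasattr(executor, "model_caches"):
        # Engines with per-device model caches.
        models = [m for cache in executor.model_caches for m in cache._models.values()]
    else:
        # SingleDeviceInferenceEngine subclasses expose the models directly.
        models = list(getattr(executor, "models", []))
    for m in models:
        for module in m.modules():
            if hasattr(module, "multiquery_item_attention_for_test_set"):
                module.multiquery_item_attention_for_test_set = enabled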
batch_size_predict: If set, predictions are batched into chunks
    of this size to avoid OOM errors. If None, no batching is performed.
batch_predict_enable_test_interaction: If False (default), test
    samples only attend to training samples during batched prediction,
    ensuring predictions match unbatched. If True, test samples can
    attend to each other within a batch, so predictions may vary
    depending on batch size.
The docstrings for batch_size_predict and batch_predict_enable_test_interaction are duplicated across multiple predict* methods (predict, predict_logits, predict_raw_logits, predict_proba, _predict_proba, and _raw_predict). This could lead to inconsistencies if they need to be updated in the future. Consider centralizing this documentation, for example by using a shared docstring template, to improve maintainability.
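One possible way to centralize the text (a sketch; the constant, decorator, and class names are illustrative, not existing repo conventions):

# Shared parameter documentation, defined once and appended to each
# predict* method's docstring via a small decorator.
_BATCH_PREDICT_PARAMS_DOC = """
    batch_size_predict: If set, predictions are batched into chunks
        of this size to avoid OOM errors. If None, no batching is performed.
    batch_predict_enable_test_interaction: If False (default), test
        samples only attend to training samples during batched prediction,
        ensuring predictions match unbatched. If True, test samples can
        attend to each other within a batch, so predictions may vary
        depending on batch size."""


def _document_batch_predict_params(func):
    """Append the shared batch-prediction parameter docs to a predict* method."""
    func.__doc__ = (func.__doc__ or "") + _BATCH_PREDICT_PARAMS_DOC
    return func


class _ExampleClassifier:
    @_document_batch_predict_params
    def predict(
        self,
        X,
        *,
        batch_size_predict=None,
        batch_predict_enable_test_interaction=False,
    ):
        """Predict class labels for X.

        Args:
            X: The test samples.
        """
        ...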
batch_size_predict:
    If set, predictions are batched into chunks of this size to avoid
    OOM errors. If None, no batching is performed.

batch_predict_enable_test_interaction:
    If False (default), test samples only attend to training samples
    during batched prediction, ensuring predictions match unbatched.
    If True, test samples can attend to each other within a batch,
    so predictions may vary depending on batch size.
The docstrings for batch_size_predict and batch_predict_enable_test_interaction are duplicated across the predict method's overloads and implementation. This could lead to inconsistencies if they need to be updated in the future. Consider centralizing this documentation, for example by using a shared docstring template, to improve maintainability. This is similar to an issue in classifier.py.