[WIP] Batch size predict #792

Closed
klemens-floege wants to merge 7 commits into main from batch_size_predict

Conversation

@klemens-floege
Contributor

No description provided.

@klemens-floege klemens-floege requested a review from a team as a code owner February 18, 2026 18:46
@klemens-floege klemens-floege requested review from noahho and removed request for a team February 18, 2026 18:46
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository-wide code reviews.

@klemens-floege klemens-floege removed the request for review from noahho February 18, 2026 18:46
@gemini-code-assist bot left a comment

Code Review

This pull request introduces an automatic batching option for the prediction methods to prevent out-of-memory errors on large test sets. However, a Denial of Service issue was identified in the set_multiquery_item_attention utility function: it does not account for inference engines that have no model cache, so batching raises an AttributeError under certain model configurations (e.g., fit_mode='batched'). The review also includes a high-severity fix for an incorrect error message and a medium-severity suggestion to deduplicate the docstrings of the new parameters to improve long-term maintainability.
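
For context, here is a minimal usage sketch of the batched-prediction option this PR adds, assuming the predict()-level batch_size_predict parameter shown in the diff; the synthetic data is only there to keep the example self-contained and is not part of the PR:

# Sketch of batched prediction as described in this PR (not a definitive API
# reference); batch_size_predict is the new parameter under review here.
import numpy as np
from tabpfn import TabPFNClassifier

rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 5))
y_train = rng.integers(0, 2, size=200)
X_test = rng.normal(size=(5_000, 5))  # stand-in for a large test set that could OOM

model = TabPFNClassifier()
model.fit(X_train, y_train)

# Predict in chunks of 1000 test rows; with the default
# batch_predict_enable_test_interaction=False, test samples only attend to
# training samples, so results should match unbatched prediction.
predictions = model.predict(X_test, batch_size_predict=1000)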

Comment on lines +68 to 73
        model_class = f"TabPFN{model_type.title()}"
        message = (
            f"{self.device_name} out of memory{size_info}.\n\n"
            f"Solution: Split your test data into smaller batches:\n\n"
            f"    batch_size = 1000  # depends on hardware\n"
            f"    predictions = []\n"
            f"    for i in range(0, len(X_test), batch_size):\n"
            f"        batch = model.{predict_method}(X_test[i:i + batch_size])\n"
            f"        predictions.append(batch)\n"
            f"    predictions = np.vstack(predictions)"
            f"Solution: Set batch_size_predict when creating the model:\n\n"
            f"    model = {model_class}(batch_size_predict=1000)"
        )

high

The suggested solution in the OOM error message is incorrect. batch_size_predict is a parameter for the predict() methods, not the TabPFN* constructor. This will confuse users trying to resolve out-of-memory issues. The message should guide them to use the parameter in the prediction call.

        message = (
            f"{self.device_name} out of memory{size_info}.\n\n"
            f"Solution: Set `batch_size_predict` in the `predict()` or `predict_proba()` method, for example:\n\n"
            f"    predictions = model.predict(X_test, batch_size_predict=1000)"
        )

Comment on lines +433 to +452
def set_multiquery_item_attention(
    model: TabPFNClassifier | TabPFNRegressor,
    *,
    enabled: bool,
) -> None:
    """Set multiquery_item_attention_for_test_set on all model layers.

    This controls whether test samples attend to each other during inference.
    Disabling it ensures predictions are consistent across different batch sizes.

    Args:
        model: The fitted TabPFN model.
        enabled: If True, test samples can attend to each other.
            If False, test samples only attend to training samples.
    """
    for model_cache in model.executor_.model_caches:
        for m in model_cache._models.values():
            for module in m.modules():
                if hasattr(module, "multiquery_item_attention_for_test_set"):
                    module.multiquery_item_attention_for_test_set = enabled

security-medium medium

The function set_multiquery_item_attention assumes that model.executor_ always has a model_caches attribute. However, SingleDeviceInferenceEngine and its subclasses (such as InferenceEngineBatchedNoPreprocessing used in fit_mode="batched" and InferenceEngineCacheKV used in fit_mode="fit_with_cache") do not have this attribute; they use a models attribute instead.

When predict() is called with batch_size_predict while the model is in one of these modes, an AttributeError will be raised, causing a crash (Denial of Service).

To fix this, check for the existence of model_caches or handle both engine types explicitly.

def set_multiquery_item_attention(
    model: TabPFNClassifier | TabPFNRegressor,
    *,
    enabled: bool,
) -> None:
    """Set multiquery_item_attention_for_test_set on all model layers.

    This controls whether test samples attend to each other during inference.
    Disabling it ensures predictions are consistent across different batch sizes.

    Args:
        model: The fitted TabPFN model.
        enabled: If True, test samples can attend to each other.
            If False, test samples only attend to training samples.
    """
    executor = model.executor_
    if hasattr(executor, "model_caches"):
        for model_cache in executor.model_caches:
            for m in model_cache._models.values():
                for module in m.modules():
                    if hasattr(module, "multiquery_item_attention_for_test_set"):
                        module.multiquery_item_attention_for_test_set = enabled
    elif hasattr(executor, "models"):
        for m in executor.models:
            for module in m.modules():
                if hasattr(module, "multiquery_item_attention_for_test_set"):
                    module.multiquery_item_attention_for_test_set = enabled

Comment on lines +1190 to +1196
        batch_size_predict: If set, predictions are batched into chunks
            of this size to avoid OOM errors. If None, no batching is performed.
        batch_predict_enable_test_interaction: If False (default), test
            samples only attend to training samples during batched prediction,
            ensuring predictions match unbatched. If True, test samples can
            attend to each other within a batch, so predictions may vary
            depending on batch size.

medium

The docstrings for batch_size_predict and batch_predict_enable_test_interaction are duplicated across multiple predict* methods (predict, predict_logits, predict_raw_logits, predict_proba, _predict_proba, and _raw_predict). This could lead to inconsistencies if they need to be updated in the future. Consider centralizing this documentation, for example by using a shared docstring template, to improve maintainability.
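
A minimal sketch of one way to centralize these docs, assuming a module-level template plus a small decorator; the names _BATCH_PREDICT_PARAMS_DOC and _with_batch_predict_docs are hypothetical and not part of the TabPFN codebase, and the parameter text is copied from the diff:

# Hypothetical helper: keep the shared parameter docs in one place and append
# them to each predict* method's docstring at import time.
_BATCH_PREDICT_PARAMS_DOC = """
        batch_size_predict: If set, predictions are batched into chunks
            of this size to avoid OOM errors. If None, no batching is performed.
        batch_predict_enable_test_interaction: If False (default), test
            samples only attend to training samples during batched prediction,
            ensuring predictions match unbatched. If True, test samples can
            attend to each other within a batch, so predictions may vary
            depending on batch size.
"""


def _with_batch_predict_docs(func):
    """Append the shared batch-prediction parameter docs to ``func.__doc__``."""
    func.__doc__ = (func.__doc__ or "") + _BATCH_PREDICT_PARAMS_DOC
    return func


class _ClassifierStub:  # stand-in for TabPFNClassifier, for illustration only
    @_with_batch_predict_docs
    def predict(
        self,
        X,
        *,
        batch_size_predict=None,
        batch_predict_enable_test_interaction=False,
    ):
        """Predict class labels for X.

        Args:
            X: The test inputs.
        """
        ...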

Comment on lines +901 to +909
    batch_size_predict:
        If set, predictions are batched into chunks of this size to avoid
        OOM errors. If None, no batching is performed.

    batch_predict_enable_test_interaction:
        If False (default), test samples only attend to training samples
        during batched prediction, ensuring predictions match unbatched.
        If True, test samples can attend to each other within a batch,
        so predictions may vary depending on batch size.

medium

The docstrings for batch_size_predict and batch_predict_enable_test_interaction are duplicated across the predict method's overloads and implementation. This could lead to inconsistencies if they need to be updated in the future. Consider centralizing this documentation, for example by using a shared docstring template, to improve maintainability. This is similar to an issue in classifier.py.
