[WIP] Batched Predictions #793

Open

klemens-floege wants to merge 2 commits into main from klemens/batched-predictions

Conversation

@klemens-floege
Contributor

No description provided.

@klemens-floege klemens-floege requested a review from a team as a code owner February 18, 2026 18:48
@klemens-floege klemens-floege requested review from priorphil and removed request for a team February 18, 2026 18:48
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo about increasing the limits by adding credits; credits are required to enable repository-wide code reviews.

@klemens-floege klemens-floege removed the request for review from priorphil February 18, 2026 18:48
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request introduces batched prediction functionality to the TabPFNClassifier and TabPFNRegressor classes, along with a new predict_in_batches utility function. This is a valuable addition for handling large datasets more efficiently and preventing out-of-memory errors. The changes also update the TabPFNOutOfMemoryError message to guide users towards the new batch_size_predict parameter. New tests have been added to ensure the batched predictions match the unbatched results across various output types. Overall, the changes are well-implemented and improve the usability of the library for large-scale inference.
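
For orientation, the helper plausibly follows this chunk-and-concatenate pattern; this is a minimal sketch inferred from the predict_in_batches signature excerpted further below, not the PR's actual implementation:

    import numpy as np

    def predict_in_batches(predict_fn, X, batch_size, concat_fn=None):
        """Run predict_fn over X in row chunks and merge the per-chunk results."""
        results = [
            predict_fn(X[start : start + batch_size])
            for start in range(0, len(X), batch_size)
        ]
        # Default merge stacks results row-wise; callers with other output
        # shapes (dicts, quantile lists, axis-1 logits) pass a custom concat_fn.
        if concat_fn is not None:
            return concat_fn(results)
        return np.concatenate(results, axis=0)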

            The predicted class labels as a NumPy array.
        """
        if batch_size_predict is not None:
            return predict_in_batches(self.predict, X, batch_size_predict)
Contributor

high

predict passes self.predict itself as the predict_fn to predict_in_batches. This terminates only because predict_in_batches calls predict_fn on each chunk without forwarding batch_size_predict, so the nested call takes the unbatched path; if that contract ever changes, the call recurses indefinitely. While Python's recursion limit would eventually catch this, it's safer to pass self._predict_proba and handle the argmax and inverse_transform in the concat step, or to otherwise ensure the predict_fn handed to predict_in_batches can never re-enter the batched path with the same batch_size_predict.

Suggested change
        return predict_in_batches(self.predict, X, batch_size_predict)
        return predict_in_batches(
            lambda chunk: np.argmax(self._predict_proba(chunk), axis=1),
            X,
            batch_size_predict,
            concat_fn=lambda results: (
                self.label_encoder_.inverse_transform(np.concatenate(results, axis=0))
                if getattr(self, "label_encoder_", None) is not None
                else np.concatenate(results, axis=0)
            ),
        )
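
Note the termination argument either way: predict_in_batches invokes predict_fn on each chunk without forwarding batch_size_predict, so the nested call takes the unbatched path. If self.predict is kept as the predict_fn, a hypothetical explicit guard would make that contract visible:

    # Hypothetical: force the unbatched path explicitly rather than relying
    # on batch_size_predict defaulting to None inside predict_in_batches.
    if batch_size_predict is not None:
        return predict_in_batches(
            lambda chunk: self.predict(chunk, batch_size_predict=None),
            X,
            batch_size_predict,
        )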

            The predicted logits as a NumPy array. Shape (n_samples, n_classes).
        """
        if batch_size_predict is not None:
            return predict_in_batches(self.predict_logits, X, batch_size_predict)
Contributor

high

As with predict, passing self.predict_logits back into predict_in_batches risks the same recursion hazard. It's better to pass the underlying _raw_predict call and apply the necessary post-processing per chunk.

Suggested change
        return predict_in_batches(self.predict_logits, X, batch_size_predict)
        return predict_in_batches(
            lambda chunk: (
                self._raw_predict(chunk, return_logits=True)
                .float()
                .detach()
                .cpu()
                .numpy()
            ),
            X,
            batch_size_predict,
        )

Comment on lines +1131 to +1137
        if batch_size_predict is not None:
            return predict_in_batches(
                self.predict_raw_logits,
                X,
                batch_size_predict,
                concat_fn=lambda results: np.concatenate(results, axis=1),
            )
Contributor

high

The predict_raw_logits method passes self.predict_raw_logits recursively to predict_in_batches. This can lead to infinite recursion. Instead, the _raw_predict method should be called directly within the lambda function.

            return predict_in_batches(
                lambda chunk: self._raw_predict(chunk, return_logits=False, return_raw_logits=True).float().detach().cpu().numpy(),
                X,
                batch_size_predict,
                concat_fn=lambda results: np.concatenate(results, axis=1),
            )
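
The axis=1 concat_fn suggests the raw logits carry test samples on the second axis, e.g. (n_estimators, n_samples, n_classes); assuming that shape, the merge behaves like this:

    import numpy as np

    # Assumed shape: (n_estimators, n_samples, n_classes).
    batch_a = np.zeros((8, 100, 3))  # logits for the first 100 test rows
    batch_b = np.zeros((8, 42, 3))   # logits for the remaining 42 rows
    merged = np.concatenate([batch_a, batch_b], axis=1)
    assert merged.shape == (8, 142, 3)  # samples rejoined on axis 1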

            Shape (n_samples, n_classes).
        """
        if batch_size_predict is not None:
            return predict_in_batches(self.predict_proba, X, batch_size_predict)
Contributor

high

Passing self.predict_proba back into predict_in_batches risks the same recursion hazard. It's more robust to pass the internal _predict_proba method directly.

            return predict_in_batches(self._predict_proba, X, batch_size_predict)

Comment on lines +898 to +908
        if batch_size_predict is not None:
            return predict_in_batches(
                lambda chunk: self.predict(
                    chunk, output_type=output_type, quantiles=quantiles
                ),
                X,
                batch_size_predict,
                concat_fn=lambda results: _concatenate_regression_results(
                    results, output_type
                ),
            )
Contributor

high

The lambda calls self.predict again inside predict_in_batches. This terminates only because batch_size_predict is not forwarded and defaults to None; it's safer to pass a lambda that invokes the internal prediction logic directly (i.e., the code after the if batch_size_predict is not None: block) rather than re-entering self.predict.

            return predict_in_batches(
                lambda chunk: self._predict_internal(chunk, output_type=output_type, quantiles=quantiles), # Assuming _predict_internal encapsulates the core logic
                X,
                batch_size_predict,
                concat_fn=lambda results: _concatenate_regression_results(
                    results, output_type
                ),
            )
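
Whichever internal hook is used, the observable contract the new tests assert is that batched and unbatched calls agree; roughly as follows (a usage sketch; the import path and exact tolerance are assumptions):

    import numpy as np
    from sklearn.datasets import make_regression
    from tabpfn import TabPFNRegressor  # import path assumed

    X, y = make_regression(n_samples=300, n_features=5, random_state=0)
    reg = TabPFNRegressor().fit(X[:200], y[:200])

    q = [0.1, 0.5, 0.9]
    full = reg.predict(X[200:], output_type="quantiles", quantiles=q)
    batched = reg.predict(
        X[200:], output_type="quantiles", quantiles=q, batch_size_predict=32
    )
    for a, b in zip(full, batched):
        np.testing.assert_allclose(a, b)  # each quantile array must match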

Comment on lines +1265 to +1266
np.concatenate([r[q] for r in results], axis=0)
for q in range(len(results[0]))
Contributor

high

In _concatenate_regression_results, when output_type == "quantiles", the code indexes r[q] where r is an element of results (typed RegressionResultType) and q is an integer, so r is assumed to be a list of per-quantile arrays. However, RegressionResultType can also be np.ndarray, MainOutputDict, or FullOutputDict. For a dict, r[q] raises a KeyError; for a plain np.ndarray (e.g., a 'mean' output), it silently selects row q instead of quantile q. The function should explicitly check the type or guarantee that results contains lists of quantiles whenever output_type is 'quantiles'.

        return [
            np.concatenate([typing.cast(list[np.ndarray], r)[q] for r in results], axis=0)
            for q in range(len(typing.cast(list[np.ndarray], results[0])))
        ]
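
An alternative to casting is a runtime check that fails loudly if a batch result is not a per-quantile list; a sketch under the types named above:

    # Sketch: verify the per-batch results are quantile lists before merging.
    if output_type == "quantiles":
        first = results[0]
        if not isinstance(first, list):
            raise TypeError(
                f"Expected per-batch quantile lists, got {type(first).__name__}"
            )
        return [
            np.concatenate([r[q] for r in results], axis=0)
            for q in range(len(first))
        ]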

Comment on lines +1271 to +1277
mean=np.concatenate([r["mean"] for r in results], axis=0),
median=np.concatenate([r["median"] for r in results], axis=0),
mode=np.concatenate([r["mode"] for r in results], axis=0),
quantiles=[
np.concatenate([r["quantiles"][q] for r in results], axis=0)
for q in range(len(results[0]["quantiles"]))
],
Contributor

high

In _concatenate_regression_results, when output_type is 'main' or 'full', the code accesses r["mean"], r["median"], r["mode"], and r["quantiles"], where r is an element of results (a list[RegressionResultType]). If a result is a plain np.ndarray (e.g., the original predict call returned only 'mean'), r["mean"] fails, since NumPy rejects string indices on unstructured arrays. The function should ensure that results contains MainOutputDict or FullOutputDict when processing 'main' or 'full' output types, or handle the np.ndarray case gracefully.

        main = MainOutputDict(
            mean=np.concatenate([typing.cast(MainOutputDict, r)["mean"] for r in results], axis=0),
            median=np.concatenate([typing.cast(MainOutputDict, r)["median"] for r in results], axis=0),
            mode=np.concatenate([typing.cast(MainOutputDict, r)["mode"] for r in results], axis=0),
            quantiles=[
                np.concatenate([typing.cast(MainOutputDict, r)["quantiles"][q] for r in results], axis=0)
                for q in range(len(typing.cast(MainOutputDict, results[0])["quantiles"])) # Cast results[0] to MainOutputDict
            ],
        )

            return main
        return FullOutputDict(
            **main,
            criterion=results[0]["criterion"],
Contributor

high

Similar to the previous comment, r["logits"] assumes r is a FullOutputDict. If r is an np.ndarray or a MainOutputDict, this will cause an error. Casting r to FullOutputDict satisfies the type checker, though it adds no runtime guarantee.

            criterion=typing.cast(FullOutputDict, results[0])["criterion"],
            logits=torch.cat([typing.cast(FullOutputDict, r)["logits"] for r in results], dim=0),

Comment on lines +494 to +499
def predict_in_batches(
    predict_fn: typing.Callable,
    X: XType,
    batch_size: int,
    concat_fn: typing.Callable | None = None,
) -> typing.Any:
Contributor

medium

The predict_in_batches function is a good addition for handling large datasets. However, the type hint for X is XType, which is a generic Any. It would be more precise to use np.ndarray or torch.Tensor as X is indexed directly, which implies it's an array-like object. This improves type safety and readability.

Suggested change
def predict_in_batches(
    predict_fn: typing.Callable,
    X: XType,
    batch_size: int,
    concat_fn: typing.Callable | None = None,
) -> typing.Any:
def predict_in_batches(
    predict_fn: typing.Callable,
    X: np.ndarray,  # More specific type hint
    batch_size: int,
    concat_fn: typing.Callable | None = None,
) -> typing.Any:
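
If torch tensors must remain accepted (both support len and row slicing), a narrower union would keep that contract explicit; a hypothetical alias:

    import typing

    import numpy as np
    import torch

    # Hypothetical: exactly the inputs that support X[start:stop] row slicing.
    BatchableX = typing.Union[np.ndarray, torch.Tensor]

    def predict_in_batches(
        predict_fn: typing.Callable,
        X: BatchableX,
        batch_size: int,
        concat_fn: typing.Callable | None = None,
    ) -> typing.Any:
        ...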

Comment on lines +72 to +75
f"Solution: Use batch_size_predict to split test data"
f" into smaller batches:\n\n"
f" model.{predict_method}("
f"X_test, batch_size_predict=100)"
Contributor

medium

The example code in the error message uses batch_size_predict=100. While this is a reasonable default, it might be beneficial to mention that the optimal batch_size_predict depends on hardware and the specific dataset, similar to how the previous message suggested batch_size = 1000 # depends on hardware. This would provide more comprehensive guidance to the user.

            f
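
The suggested change body appears truncated above; purely as an illustration (hypothetical wording, not the bot's actual suggestion), the caveat could be folded into the message fragment like so:

    # Hypothetical extension of the error-message fragment:
    f"Solution: Use batch_size_predict to split test data"
    f" into smaller batches (the optimal value depends on your"
    f" hardware and dataset):\n\n"
    f" model.{predict_method}(X_test, batch_size_predict=100)"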
