Conversation
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Code Review
The pull request introduces batched prediction functionality to the TabPFNClassifier and TabPFNRegressor classes, along with a new predict_in_batches utility function. This is a valuable addition for handling large datasets more efficiently and preventing out-of-memory errors. The changes also update the TabPFNOutOfMemoryError message to guide users towards the new batch_size_predict parameter. New tests have been added to ensure the batched predictions match the unbatched results across various output types. Overall, the changes are well-implemented and improve the usability of the library for large-scale inference.
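For orientation, here is a minimal usage sketch of the feature under review. It assumes the public API described in this PR (a fitted TabPFNClassifier plus the new batch_size_predict keyword); the dataset shapes and the chunk size of 512 are purely illustrative.

```python
import numpy as np
from tabpfn import TabPFNClassifier

# Illustrative data; sizes chosen only to mimic a large test set.
X_train, y_train = np.random.rand(1_000, 10), np.random.randint(0, 2, 1_000)
X_test = np.random.rand(50_000, 10)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# Without batching, a large X_test may trigger TabPFNOutOfMemoryError;
# batch_size_predict (added in this PR) splits X_test into chunks of 512 rows.
y_pred = clf.predict(X_test, batch_size_predict=512)
proba = clf.predict_proba(X_test, batch_size_predict=512)
```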
```python
        The predicted class labels as a NumPy array.
        """
        if batch_size_predict is not None:
            return predict_in_batches(self.predict, X, batch_size_predict)
```
There was a problem hiding this comment.
Here predict passes self.predict as the predict_fn, so predict_in_batches re-enters predict. This can lead to infinite recursion if not handled carefully; Python's default recursion limit might catch it, but it is safer to pass self._predict_proba and apply the argmax and inverse_transform outside the predict_in_batches call, or to ensure that the predict_fn handed to predict_in_batches never calls predict_in_batches again with the same batch_size_predict parameter.
Suggested change:

```diff
-            return predict_in_batches(self.predict, X, batch_size_predict)
+            return predict_in_batches(
+                lambda chunk: np.argmax(self._predict_proba(chunk), axis=1),
+                X,
+                batch_size_predict,
+                concat_fn=lambda results: (
+                    self.label_encoder_.inverse_transform(np.concatenate(results, axis=0))
+                    if hasattr(self, "label_encoder_") and self.label_encoder_ is not None
+                    else np.concatenate(results, axis=0)
+                ),
+            )
```
```python
        The predicted logits as a NumPy array. Shape (n_samples, n_classes).
        """
        if batch_size_predict is not None:
            return predict_in_batches(self.predict_logits, X, batch_size_predict)
```
As with predict, passing self.predict_logits back into predict_in_batches can lead to infinite recursion. It is better to pass the underlying _raw_predict function and apply the necessary post-processing afterwards.
Suggested change:

```diff
-            return predict_in_batches(self.predict_logits, X, batch_size_predict)
+            return predict_in_batches(
+                lambda chunk: self._raw_predict(chunk, return_logits=True).float().detach().cpu().numpy(),
+                X,
+                batch_size_predict,
+            )
```
```python
        if batch_size_predict is not None:
            return predict_in_batches(
                self.predict_raw_logits,
                X,
                batch_size_predict,
                concat_fn=lambda results: np.concatenate(results, axis=1),
            )
```
The predict_raw_logits method passes itself to predict_in_batches, which can lead to infinite recursion. Instead, _raw_predict should be called directly within the lambda function.
```python
            return predict_in_batches(
                lambda chunk: self._raw_predict(
                    chunk, return_logits=False, return_raw_logits=True
                ).float().detach().cpu().numpy(),
                X,
                batch_size_predict,
                concat_fn=lambda results: np.concatenate(results, axis=1),
            )
```

```python
            Shape (n_samples, n_classes).
        """
        if batch_size_predict is not None:
            return predict_in_batches(self.predict_proba, X, batch_size_predict)
```
```python
        if batch_size_predict is not None:
            return predict_in_batches(
                lambda chunk: self.predict(
                    chunk, output_type=output_type, quantiles=quantiles
                ),
                X,
                batch_size_predict,
                concat_fn=lambda results: _concatenate_regression_results(
                    results, output_type
                ),
            )
```
Here the lambda calls self.predict again from inside predict_in_batches, so the method effectively calls itself. This can lead to infinite recursion if not handled carefully. It is safer to pass a lambda that calls the internal prediction logic directly (i.e., the code after the if batch_size_predict is not None: block) rather than calling self.predict again.
```python
            return predict_in_batches(
                # Assuming a _predict_internal helper encapsulates the core logic.
                lambda chunk: self._predict_internal(
                    chunk, output_type=output_type, quantiles=quantiles
                ),
                X,
                batch_size_predict,
                concat_fn=lambda results: _concatenate_regression_results(
                    results, output_type
                ),
            )
```

```python
            np.concatenate([r[q] for r in results], axis=0)
            for q in range(len(results[0]))
```
In the _concatenate_regression_results function, when output_type == "quantiles" the code accesses r[q], where r is an element of results (typed as RegressionResultType) and q is an integer. This implies r is expected to be a list or array of quantiles, but RegressionResultType can also be np.ndarray, MainOutputDict, or FullOutputDict, so the access could raise a TypeError if r is not indexable by an integer q (e.g., an np.ndarray representing a single output type like 'mean'). The function should explicitly check the type, or ensure that results always contains lists of quantiles when output_type is 'quantiles'.
```python
        return [
            np.concatenate(
                [typing.cast(list[np.ndarray], r)[q] for r in results], axis=0
            )
            for q in range(len(typing.cast(list[np.ndarray], results[0])))
        ]
```

```python
        mean=np.concatenate([r["mean"] for r in results], axis=0),
        median=np.concatenate([r["median"] for r in results], axis=0),
        mode=np.concatenate([r["mode"] for r in results], axis=0),
        quantiles=[
            np.concatenate([r["quantiles"][q] for r in results], axis=0)
            for q in range(len(results[0]["quantiles"]))
        ],
```
In the _concatenate_regression_results function, when output_type is 'main' or 'full', the code accesses r["mean"], r["median"], r["mode"], and r["quantiles"], where r is an element of results (a list[RegressionResultType]). If RegressionResultType is np.ndarray (e.g., if the original predict call returned only 'mean'), then accessing r["mean"] would raise a TypeError. The function should ensure that results contains MainOutputDict or FullOutputDict when processing 'main' or 'full' output types, or handle the np.ndarray case gracefully.
```python
    main = MainOutputDict(
        mean=np.concatenate(
            [typing.cast(MainOutputDict, r)["mean"] for r in results], axis=0
        ),
        median=np.concatenate(
            [typing.cast(MainOutputDict, r)["median"] for r in results], axis=0
        ),
        mode=np.concatenate(
            [typing.cast(MainOutputDict, r)["mode"] for r in results], axis=0
        ),
        quantiles=[
            np.concatenate(
                [typing.cast(MainOutputDict, r)["quantiles"][q] for r in results],
                axis=0,
            )
            # Cast results[0] to MainOutputDict for the quantile count.
            for q in range(len(typing.cast(MainOutputDict, results[0])["quantiles"]))
        ],
    )
```

```python
        return main
    return FullOutputDict(
        **main,
        criterion=results[0]["criterion"],
```
Similar to the previous comment, r["logits"] assumes r is a FullOutputDict. If r is an np.ndarray or MainOutputDict, this will cause an error. Type casting r to FullOutputDict would make this access safe.
```python
        criterion=typing.cast(FullOutputDict, results[0])["criterion"],
        logits=torch.cat(
            [typing.cast(FullOutputDict, r)["logits"] for r in results], dim=0
        ),
```

```python
def predict_in_batches(
    predict_fn: typing.Callable,
    X: XType,
    batch_size: int,
    concat_fn: typing.Callable | None = None,
) -> typing.Any:
```
The predict_in_batches function is a good addition for handling large datasets. However, the type hint for X is XType, which resolves to a generic Any. Since X is indexed directly, it must be an array-like object, so a more precise hint such as np.ndarray or torch.Tensor would improve type safety and readability.
Suggested change:

```diff
 def predict_in_batches(
     predict_fn: typing.Callable,
-    X: XType,
+    X: np.ndarray,  # More specific type hint
     batch_size: int,
     concat_fn: typing.Callable | None = None,
 ) -> typing.Any:
```
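For readers following along, here is a rough sketch of what a helper with this signature presumably does, inferred from how it is called throughout this review. The PR's actual implementation may differ, and defaulting concat_fn to np.concatenate along axis 0 is an assumption.

```python
import typing

import numpy as np


def predict_in_batches(
    predict_fn: typing.Callable,
    X: np.ndarray,
    batch_size: int,
    concat_fn: typing.Callable | None = None,
) -> typing.Any:
    """Run predict_fn on row-chunks of X and stitch the per-chunk results together.

    Sketch only: mirrors the call sites quoted above, not the PR's exact code.
    """
    results = [
        predict_fn(X[start : start + batch_size])
        for start in range(0, len(X), batch_size)
    ]
    if concat_fn is not None:
        return concat_fn(results)
    # Assumed default: row-wise concatenation of per-batch outputs.
    return np.concatenate(results, axis=0)
```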
| f"Solution: Use batch_size_predict to split test data" | ||
| f" into smaller batches:\n\n" | ||
| f" model.{predict_method}(" | ||
| f"X_test, batch_size_predict=100)" |
The example code in the error message uses batch_size_predict=100. While this is a reasonable default, it might be beneficial to mention that the optimal batch_size_predict depends on hardware and the specific dataset, similar to how the previous message suggested batch_size = 1000 # depends on hardware. This would provide more comprehensive guidance to the user.
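One possible wording, sketched only to illustrate the suggestion above; the helper name _oom_hint is hypothetical, and this is not the PR's actual message.

```python
def _oom_hint(predict_method: str) -> str:
    # Hypothetical helper, illustrative wording only: it extends the quoted
    # f-string with the hardware/dataset caveat suggested in the comment above.
    return (
        f"Solution: Use batch_size_predict to split test data"
        f" into smaller batches:\n\n"
        f"    model.{predict_method}(X_test, batch_size_predict=100)\n\n"
        "The optimal batch_size_predict depends on your hardware and dataset;"
        " decrease it further if you still run out of memory."
    )
```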