reduce linear operator overhead in exact marginal log likelihood computation #2682
kayween wants to merge 6 commits into cornellius-gp:main from
Conversation
cc @saitcakmak who has done some investigations in the past into removing LinearOperator overhead. I'll let him give the details, but IIRC the high-level tl;dr was that while we did see speedups, those didn't amount to dramatic savings in the grand scheme of things for our typical use cases. But our use cases aren't necessarily standard, so I think there could be meaningful value in allowing users to short-circuit the linear operator overhead in the small-data regime.
I think this is great! What I had done was to take out linear operator and any parts of GPyTorch that we didn't need for ExactGPs and create a bare-bones version of the library (~20% of GPyTorch and no linear operator). I had seen up to 2x faster execution for some model operations, but at the end of the day, it wasn't that significant when considered as part of the whole BO loop in Ax. I think introducing small (depending on where we measure from) improvements like this to core GPyTorch is great. I'd definitely use this for ExactGPs in BoTorch.
I had a question while implementing this PR, which I think is worth some discussion. Note that this PR does implement a custom backward pass for the log determinant. In terms of running time, this custom backward pass for logdet is indeed faster than PyTorch's default backward pass. Numerically, computing the inverse by two triangular solves shouldn't be bad, since the backward pass of the Cholesky decomposition requires two triangular solves anyway. (I am assuming PyTorch implements the Cholesky backward pass by something similar to Eq. (9) in Iain Murray's note.)
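For concreteness, here is a minimal sketch (not the PR's actual implementation) of a logdet with a hand-written backward pass: for symmetric positive definite K, the gradient of log det(K) with respect to K is K^{-1}, which can be assembled from the saved Cholesky factor with two triangular solves.

```python
import torch

class CholeskyLogDet(torch.autograd.Function):
    """Sketch of logdet with a custom backward pass (illustrative only).

    Forward:  logdet(K) = 2 * sum(log(diag(L)))  where K = L L^T.
    Backward: d logdet(K) / dK = K^{-1}, assembled from L via two
              triangular solves, since K^{-1} = L^{-T} L^{-1}.
    """

    @staticmethod
    def forward(ctx, K):
        L = torch.linalg.cholesky(K)
        ctx.save_for_backward(L)
        return 2.0 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(-1)

    @staticmethod
    def backward(ctx, grad_output):
        (L,) = ctx.saved_tensors
        eye = torch.eye(L.shape[-1], dtype=L.dtype, device=L.device)
        Linv = torch.linalg.solve_triangular(L, eye, upper=False)     # L^{-1}
        Kinv = torch.linalg.solve_triangular(L.mT, Linv, upper=True)  # L^{-T} L^{-1}
        return grad_output * Kinv

# Quick check against PyTorch's default autograd through torch.logdet
torch.manual_seed(0)
A = torch.randn(6, 6, dtype=torch.float64)
K = A @ A.mT + 6.0 * torch.eye(6, dtype=torch.float64)

K1 = K.clone().requires_grad_(True)
CholeskyLogDet.apply(K1).backward()
K2 = K.clone().requires_grad_(True)
torch.logdet(K2).backward()
```

The saved Cholesky factor is reused by the backward pass, so no second factorization is needed there.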
I am not sure if there were any special considerations at the time - @gpleiss or @jacobrgardner might know? |
I dug around and the only remotely related issue to Cholesky derivatives I could find was this one: pytorch/pytorch#18825. I think we just assumed the pytorch default derivatives for these ops would be fast. I certainly have no problem with us merging a custom backward pass in the name of 10-40% speed-ups.
`_default = True`

`class use_torch_tensors(_feature_flag):`
What do we think about making this on by default up to some N?
Indeed, the first version of this PR turns on this flag up to some N as you suggested. But the benchmark shows a speedup even for N=1000 (whereas the default threshold for Cholesky decomposition is N=800). So I decided to turn this on as long as Cholesky decomposition is used for training and inference.
I think the design here is intertwined with your comments below about what would happen for larger N. I'll circle back on this once we have benchmark results for larger N.
@kayween One thing I'm noticing: the speedup is increasing with matrix size! That is unexpected: one would assume that the additional overhead of linear operator packaging becomes negligible as the matrix size increases. Can we extend the benchmark out to larger N? If there isn't some point where we stop seeing increasing speed-ups, there's obviously some serious problem with the default pytorch ops.
I dug along this line a bit. @jacobrgardner and anyone who is curious, this is how PyTorch currently implements the Cholesky backward pass, which involves two triangular solves and a matmul. I haven't dug into
@jacobrgardner Here are more benchmark results for larger N. The code generating the results is available in this gist. As before, all benchmark runs use Cholesky decomposition for GP training.

Running time (in seconds) on synthetic datasets:
The time reduction doesn't seem to saturate as we increase the size of the training data! For example, we're seeing about a 61% time reduction on CPU when
I don't think the issue is necessarily with PyTorch ops. There are probably two explanations for why the speed-up does not stop at larger dataset sizes.
@kayween this doesn't make sense to me. The representation tree at the end of the day should just be carrying around the underlying pytorch tensors as pointers and reconstructing the thin python wrapper. Like, in a vanilla GP, we have an

All that is to say: there's no way that python overhead, which should essentially be constant, dominates the linear algebra being done by
Sure, this makes sense to me.
@jacobrgardner The following code snippet shows that linear operators still have about 30% overhead compared to PyTorch even at the scale of
Pull request overview
This PR introduces a performance optimization for exact marginal log likelihood computation by reducing linear operator overhead when working with moderate-sized dense matrices. The implementation allows users to opt into direct tensor operations via a new `use_torch_tensors` setting, achieving runtime reductions of 6-44% across different hardware and precision configurations.
Key Changes:
- Introduces a custom `TensorInvQuadLogdet` autograd function that computes inverse quadratic forms and log determinants directly on torch tensors using Cholesky decomposition
- Adds a `use_torch_tensors` feature flag to enable the optimization
- Modifies `MultivariateNormal.log_prob` to use the tensor-based implementation when the flag is enabled and Cholesky decomposition is appropriate
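As a rough illustration of what such a tensor-level routine computes (the function name below is mine, not the PR's), both the inverse quadratic form b^T K^{-1} b and log det(K) fall out of a single Cholesky factorization:

```python
import torch

def inv_quad_logdet_dense(K: torch.Tensor, b: torch.Tensor):
    """Illustrative tensor-only inv-quad/logdet (not the PR's implementation).

    Returns (b^T K^{-1} b, logdet K) from one Cholesky factorization.
    """
    L = torch.linalg.cholesky(K)                          # K = L L^T
    w = torch.linalg.solve_triangular(L, b, upper=False)  # w = L^{-1} b
    inv_quad = (w * w).sum()                              # b^T K^{-1} b = ||w||^2
    logdet = 2.0 * torch.log(torch.diagonal(L)).sum()
    return inv_quad, logdet
```

The point of such a routine is that everything stays a plain `torch.Tensor`, with no linear operator wrapping between the Cholesky factorization and the two outputs.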
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| `gpytorch/functions/inv_quad_logdet.py` | Implements `TensorInvQuadLogdet`, a custom autograd function for computing inverse quadratic forms and log determinants on tensors, bypassing linear operator overhead |
| `gpytorch/functions/__init__.py` | Exports the new `TensorInvQuadLogdet` function |
| `gpytorch/distributions/multivariate_normal.py` | Modifies `log_prob` to conditionally use `TensorInvQuadLogdet` when `use_torch_tensors` is enabled and the matrix size is suitable for Cholesky |
| `gpytorch/settings.py` | Adds the `use_torch_tensors` feature flag (default: `False`) and exports it in `__all__` |
| `test/functions/test_inv_quad_logdet.py` | Adds tests comparing `TensorInvQuadLogdet` against the linear operator implementation for both unbatched and batched cases |
jacobrgardner left a comment
This also looks good to me pending the change to `psd_safe_cholesky`.
I'd like to get this one in also, but there are two questions to solve:
It looks like a good amount of time is spent on the `contiguous` call. Here is a small code snippet reproducing this. It's about 23% slower if we force the Cholesky factor to be contiguous:

```python
import torch

n = 10_000
A = torch.randn(n, n) / n
A = A + A.mT + torch.eye(n)
b = torch.randn(n, 1) / n

def solve(A, b, force_contiguous: bool) -> tuple[torch.Tensor, torch.Tensor]:
    chol = torch.linalg.cholesky(A, upper=False)
    if force_contiguous:
        chol = chol.contiguous()
    res = torch.linalg.solve_triangular(chol, b, upper=False)
    return chol, res

chol, res = solve(A, b, force_contiguous=False)
print(chol.is_contiguous())  # False
print(res.is_contiguous())   # True

chol, res = solve(A, b, force_contiguous=True)
print(chol.is_contiguous())  # True
print(res.is_contiguous())   # True

%timeit -r 20 solve(A, b, force_contiguous=False)
%timeit -r 20 solve(A, b, force_contiguous=True)

# outputs:
# 546 ms ± 31.1 ms per loop (mean ± std. dev. of 20 runs, 1 loop each)
# 713 ms ± 17.4 ms per loop (mean ± std. dev. of 20 runs, 1 loop each)
```

Note that I've tried removing the contiguous call in linear operator. And it seems that
It's also not entirely clear to me why we need this contiguous call. It seems that all test cases (both linear operator and gpytorch) pass without it. cc @Balandat, who wrote the comment that this contiguous call is necessary and might know more background on this.
Interesting. I think contiguous should only be necessary if it actually crashes without it. @Balandat can you check the botorch unit tests? My preference would be to drop the contiguous call if it's unnecessary and removing it leads to 20-40%+ speed-ups for exact GPs. Assuming it works, I guess my proposal would be to use tensors by default up to some value (1000? 4000?), drop contiguous, and switch to linear operator above that.
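The proposed policy could be sketched as a size-based dispatch; the threshold and function name below are hypothetical, and the large-N branch is just a placeholder for the existing linear operator path:

```python
import torch

# Hypothetical threshold: below this, operate on raw tensors (value TBD, e.g. 1000-4000).
TENSOR_PATH_MAX_SIZE = 4000

def inv_quad_logdet_dispatch(K: torch.Tensor, b: torch.Tensor):
    """Illustrative dispatch between a dense tensor path and linear operator."""
    if K.shape[-1] <= TENSOR_PATH_MAX_SIZE:
        # Small/moderate N: dense Cholesky directly on torch tensors.
        L = torch.linalg.cholesky(K)
        w = torch.linalg.solve_triangular(L, b, upper=False)
        return (w * w).sum(), 2.0 * torch.log(torch.diagonal(L)).sum()
    # Large N: fall back to the linear operator code path (placeholder).
    raise NotImplementedError("delegate to linear_operator for large N")
```

Whether the crossover point should match the existing Cholesky-vs-iterative threshold (N=800) or sit higher is exactly the open benchmarking question.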
I looked at the PR where I introduced this contiguous call. I ran the botorch and Ax unit tests on these changes; I didn't run into any failures. That doesn't necessarily mean that it's safe though; let me also kick off some e2e benchmarks to stress test this.
I didn't see anything concerning in the benchmark runs either - I think we should be ok with removing the contiguous call.
That makes sense to me.
@kayween can we re-benchmark with |
I ran GP training comparing the following three methods on 4 UCI datasets:
The results are available here: https://wandb.ai/kayween/benchmark-gpytorh-2628/reports/PR-2628-Benchmark--VmlldzoxNTU1MDMwMg?accessToken=1f3l6805fm0r1dmpw5ti7ytowmv3bao3xuj1ag3ipirpr5vxfefaiyb6px1w8nkc

TL;DR:
So I guess this new round of benchmarks does not exactly answer the questions of (a) how much linear operator overhead there is and (b) where the overhead is located. To fully investigate this, we will need to implement a different version of
Linear operators are great for large structured matrices, but they can incur overhead for moderate-sized dense matrices. This PR kick-starts exploring the headroom from reducing linear operator overhead.
What's Changed?
Wrap the model training code in the `use_torch_tensors` context manager (added in `gpytorch/settings.py`). Then, the exact log marginal likelihood is computed by doing linear algebra operations directly on torch tensors (as opposed to linear operators).
Under the hood, this PR modifies `MultivariateNormal.log_prob` and computes the log marginal likelihood by a custom `inv_quad_logdet` implementation that takes tensors as inputs. The exact GP prediction strategy is not modified yet; thus, the test-time behavior is not affected.
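For reference, a Gaussian log density decomposes into exactly the two quantities an inv-quad/logdet routine provides. A small illustrative sketch (not the PR's code):

```python
import math

import torch

def mvn_log_prob(mean: torch.Tensor, K: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Illustrative Gaussian log density built from inv-quad and logdet terms.

    log N(y; mean, K) = -0.5 * [(y - mean)^T K^{-1} (y - mean)
                                + logdet(K) + n * log(2 * pi)]
    """
    n = y.shape[-1]
    diff = (y - mean).unsqueeze(-1)
    L = torch.linalg.cholesky(K)
    w = torch.linalg.solve_triangular(L, diff, upper=False)
    inv_quad = (w * w).sum()                           # (y - mean)^T K^{-1} (y - mean)
    logdet = 2.0 * torch.log(torch.diagonal(L)).sum()  # log det(K)
    return -0.5 * (inv_quad + logdet + n * math.log(2 * math.pi))
```

Since both terms come from the same Cholesky factor, computing them in one fused routine on plain tensors avoids any per-operation wrapper overhead.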
Benchmark Model Fitting Time
The numbers in the following tables are obtained from this notebook.
The runtime improvement seems consistent, but the size of the reduction varies across settings. The most significant speedup happens with double precision on CPUs, where the runtime reduction is up to 43% (which might be a bit surprising). Meanwhile, the improvement with single precision on GPUs is less dramatic.
GP Prediction Performance
The MAE of the trained GPs on the synthetic data is virtually the same. So this PR does not seem to regress model performance.