Skip to content

reduce linear operator overhead in exact marginal log likelihood computation#2682

Draft
kayween wants to merge 6 commits intocornellius-gp:mainfrom
kayween:linop-overhead-exact-mll
Draft

reduce linear operator overhead in exact marginal log likelihood computation#2682
kayween wants to merge 6 commits intocornellius-gp:mainfrom
kayween:linop-overhead-exact-mll

Conversation

@kayween
Copy link
Copy Markdown
Collaborator

@kayween kayween commented Nov 19, 2025

Linear operators are great for large structured matrices. But it might incur overhead for moderate-sized dense matrices. This PR kick-starts exploring the headroom of reducing linear operator overhead.

What's Changed?

Wrap the model training code in the following context manager. Then, the exact log marginal likelihood is computed by doing linear algebra operations directly on torch tensors (as opposed to linear operators).

with settings.use_torch_tensors(True):
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()

Under the hood, this PR modifes MultivariateNormal.log_prob and computes the log marginal likelihood by a custom inv_quad_logdet implementation that takes tensors as inputs.

The exact GP prediction strategy is not modified yet. Thus, the test-time behavior is not affected.

Benchmark Model Fitting Time

The numbers in the following tables are obtained from this notebook.

  • All GP models are configured to use Cholesky decomposition for inference. No CG is involved.
  • The runtime in each table cell is timed over 200 replications.

The runtime improvement seems consistent. But the runtime reduction varies across different settings. The most significant speed up happens with double precision on CPUs, where the runtime reduction is up to 43% (which might be a bit surprising). Meanwhile, the improvement with single precision on GPUs is less dramatic.

model fitting running time (float32, CPU)
dataset size501005001000
main branch (s)13.815.344.6169.8
this PR (s)11.013.332.5110.8
runtime reduction-20.3%-13.1%-27.1%-34.7%
model fitting running time (float64, CPU)
dataset size501005001000
main branch (s)14.616.272.3342.4
this PR (s)12.314.148.4192.4
runtime reduction-15.8%-13.0%-33.1%-43.8%
model fitting running time (float32, GPU)
dataset size501005001000
main branch (s)19.417.219.827.3
this PR (s)17.216.118.323.6
runtime reduction-11.3%-6.4%-7.6%-13.6%
model fitting running time (float64, GPU)
dataset size501005001000
main branch (s)19.422.465.1184.9
this PR (s)17.420.256.5143.5
runtime reduction-10.3%-9.8%-13.2%-22.4%

GP Prediction Performance

The MAE of the trained GPs on the synthetic data is virtually the same. So this PR does not seem to regress model performance.

@kayween
Copy link
Copy Markdown
Collaborator Author

kayween commented Nov 19, 2025

cc @jacobrgardner @gpleiss @Balandat

@Balandat
Copy link
Copy Markdown
Collaborator

cc @saitcakmak who has done some investigations in the past into removing LinearOperator overhead. I'll let him give the details, but IIRC the high level tl;dr was that while we did see speedups those didn't amount to dramatic savings in the grand scheme of things for our typical use cases. But our use cases aren't necessarily standard so I think there could be meaningful value in allowing to short-circuit the linear operator overhead in the small data regime.

@saitcakmak
Copy link
Copy Markdown
Collaborator

I think this is great! What I had done was to take out linear operator and any parts of GPyTorch that we didn't need for ExactGPs and create a bare-bone version of the library (~20% of GPyTorch and no linear operator). I had seen up to 2x faster execution for some model operations, but at the end of the day, it wasn't that significant when considered as part of the whole BO loop in Ax. I think introducing small (depending on where we measure from) improvements like this to core GPyTorch is great. I'd definitely use this for ExactGPs in BoTorch.

@kayween
Copy link
Copy Markdown
Collaborator Author

kayween commented Nov 20, 2025

I had a question while implementing this PR, which I think worth some discussions. Note that CholLinearOperator does not implement a custom backward pass for the log determinant. Instead, it relies on PyTorch's default backward pass, which backprop through the Cholesky decomposition. Were there any special considerations there?

This PR does implement a custom backward pass for log determinant by d logdet(K) = K^{-1}, which involves two triangular solves with the cached Cholesky factor of K.

In terms of the running time, this custom backward pass for logdet is indeed faster than PyTorch's default backward pass. Numerically, computing the inverse by two triangular solves shouldn't be bad since the backward pass of Cholesky decomposition requires two triangular solves anyway.

(I am assuming PyTorch implements the Cholesky backward pass by something similar to Eq (9) in Iain Murray's note.)

@Balandat
Copy link
Copy Markdown
Collaborator

I am not sure if there were any special considerations at the time - @gpleiss or @jacobrgardner might know?

@jacobrgardner
Copy link
Copy Markdown
Member

I dug around and the only remotely related issue to cholesky derivatives I could find was this one pytorch/pytorch#18825.

I think we just assumed the pytorch default derivatives for these ops would be fast. I certainly have no problem with us merging a custom backward pass in the name of 10-40% speed ups.

Comment thread gpytorch/functions/inv_quad_logdet.py Outdated
Comment thread gpytorch/settings.py
_default = True


class use_torch_tensors(_feature_flag):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do we think about making this on by default up to some N?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, the first version of this PR turns on this flag up to some N as you suggested. But the benchmark shows speed up even for N=1000 (whereas the default threshold for Cholesky decomposition is N=800). So I decided to turns this on as long as Cholesky decomposition is used for training and inference.

I think the design here is intertwined with your comments below---what would happen for larger N. I'll circle back on this once we have benchmark results for larger N.

@jacobrgardner
Copy link
Copy Markdown
Member

@kayween One thing I'm noticing: the speedup is increasing with matrix size! That is unexpected: one would assume that the additional overhead of linear operator packaging becomes negligible as the matrix size increases.

Can we extend the benchmark out to larger N? If there isn't some point where we stop seeing increasing speed-ups, there's obviously some serious problem with the default pytorch ops.

@kayween
Copy link
Copy Markdown
Collaborator Author

kayween commented Nov 22, 2025

I dug around and the only remotely related issue to Cholesky derivatives I could find was this one pytorch/pytorch#18825.

I dug along this line a bit. @jacobrgardner and anyone who is curious, this is how the current PyTorch implements the Cholesky backward pass, which involves two triangular solves and a matmul.

https://github.com/pytorch/pytorch/blob/a3cc252e03572835c15afde54b81fc5e8616ad27/torch/csrc/autograd/FunctionsManual.cpp#L1984-L2010

I haven't dug into cholesky_inverse too deep since it is a part of LAPACK. But I assume it's based on two triangular solves. Thus, the custom backward pass of logdet in this PR saves (at least) a matmul compared to backproping through Cholesky decomposition.

@kayween
Copy link
Copy Markdown
Collaborator Author

kayween commented Nov 25, 2025

@jacobrgardner Here are more benchmark results for larger N. The code generating the results is available in this gist. As with before, all benchmark runs use Cholesky decomposition for GP training.

Running time (in seconds) on synthetic datasets

dataset sizes CPU f32 CPU f64 GPU f32 GPU f64
2000 5.2 → 2.3 (-55.8%) 9.5 → 4.4 (-53.7%) 4.5 → 3.1 (-31.1%) 4.4 → 3.1 (-29.5%)
5000 57.7 → 24.3 (-57.9%) 110.3 → 47.2 (-57.2%) 2.3 → 1.8 (-21.7%) 57.7 → 34.2 (-40.7%)
10000 362.8 → 138.6 (-61.8%) 739.2 → 288.1 (-61.0%) 13.6 → 9.5 (-30.1%) 442.3 → 254.8 (-42.4%)

The time reduction doesn't seem to saturate as we increase the size of training data! For example, we're seeing about 61% time reduction on CPU when n = 10000.

If there isn't some point where we stop seeing increasing speed-ups, there's obviously some serious problem with the default pytorch ops.

I don't think the issue is necessarily with PyTorch ops. There are probably two explanations why the speed-up does not stop with larger dataset sizes.

  1. My hunch is that linear operators' representation tree implementation is not very efficient compared to PyTorch ops, which is where the overhead comes from. For example, InvQuad needs to reconstruct those linear operators in the forward and backward passes.
  2. This PR has a custom backward pass for logdet, which is not available in linear operators.

cc @gpleiss @Balandat @saitcakmak

@jacobrgardner
Copy link
Copy Markdown
Member

My hunch is that linear operators' representation tree implementation is not very efficient compared to PyTorch ops, which is where the overhead comes from. For example, InvQuad needs to reconstruct those linear operators in the forward and backward passes.

@kayween this doesn't make sense to me. The representation tree at the end of the day should just be carrying around the underlying pytorch tensors as pointers and reconstructing the thin python wrapper.

Like, in a vanilla GP, we have an AddedDiagLinearOperator(DenseLinearOperator(K), DiagLinearOperator(\sigma^2)). The representation tree for this has, at its roots, tensor(K) and tensor(\sigma^2). The pytorch Function takes those tensors as inputs, and just reconstructs AddedDiagLinearOperator(...) in the forward pass of the Function.

All that is to say: there's no way that python overhead, which should essentially be constant, dominates the linear algebra being done by n = 10,000. Or, more importantly, that overhead certainly shouldn't be growing with n!

This PR has a custom backward pass for logdet, which is not available in linear operators.

Sure, this makes sense to me.

@kayween
Copy link
Copy Markdown
Collaborator Author

kayween commented Nov 26, 2025

@jacobrgardner This following code snippet shows that linear operators still have about 30% overhead compared to PyTorch even at the scale of n = 10000. This overhead may not come from the representation tree, but it seems like it has to be somewhere inside the inv_quad function call.

import linear_operator
import torch

from gpytorch.kernels import RBFKernel

torch.manual_seed(42)

n = 10_000

train_x = torch.rand(n, 5)
train_y = torch.rand(n)


def compute_inv_quad(use_linop: bool):
    covar_module = RBFKernel()
    covar = covar_module(train_x).evaluate_kernel().add_jitter(1e-3)

    if use_linop:
        with linear_operator.settings.fast_computations(False, False, False):
            return covar.inv_quad(inv_quad_rhs=train_y.unsqueeze(-1))
    else:
        covar = covar.to_dense()
        chol = torch.linalg.cholesky(covar)

        inv_chol_rhs = torch.linalg.solve_triangular(chol, train_y.unsqueeze(-1), upper=False).squeeze(-1)
        return (inv_chol_rhs**2).sum(-1)


%timeit compute_inv_quad(use_linop=True)
%timeit compute_inv_quad(use_linop=False)

# outputs:
# 1.31 s ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 912 ms ± 2.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Copilot AI review requested due to automatic review settings December 24, 2025 03:07
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces a performance optimization for exact marginal log likelihood computation by reducing linear operator overhead when working with moderate-sized dense matrices. The implementation allows users to opt into using direct tensor operations via a new use_torch_tensors setting, achieving runtime reductions of 6-44% across different hardware and precision configurations.

Key Changes:

  • Introduces a custom TensorInvQuadLogdet autograd function that computes inverse quadratic forms and log determinants directly on torch tensors using Cholesky decomposition
  • Adds a use_torch_tensors feature flag to enable the optimization
  • Modifies MultivariateNormal.log_prob to use the tensor-based implementation when the flag is enabled and Cholesky decomposition is appropriate

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 12 comments.

Show a summary per file
File Description
gpytorch/functions/inv_quad_logdet.py Implements TensorInvQuadLogdet, a custom autograd function for computing inverse quadratic forms and log determinants on tensors, bypassing linear operator overhead
gpytorch/functions/init.py Exports the new TensorInvQuadLogdet function
gpytorch/distributions/multivariate_normal.py Modifies log_prob to conditionally use TensorInvQuadLogdet when use_torch_tensors is enabled and matrix size is suitable for Cholesky
gpytorch/settings.py Adds use_torch_tensors feature flag (default: False) and exports it in __all__
test/functions/test_inv_quad_logdet.py Adds tests comparing TensorInvQuadLogdet against linear operator implementation for both unbatched and batched cases

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread test/functions/test_inv_quad_logdet.py Outdated
Comment thread test/functions/test_inv_quad_logdet.py Outdated
Comment thread test/functions/test_inv_quad_logdet.py Outdated
Comment thread test/functions/test_inv_quad_logdet.py
Comment thread gpytorch/functions/inv_quad_logdet.py
Comment thread test/functions/test_inv_quad_logdet.py
Comment thread test/functions/test_inv_quad_logdet.py
Comment thread gpytorch/distributions/multivariate_normal.py Outdated
Comment thread gpytorch/functions/inv_quad_logdet.py Outdated
Comment thread gpytorch/functions/inv_quad_logdet.py
Copy link
Copy Markdown
Member

@jacobrgardner jacobrgardner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also looks good to me pending the change to psd_safe_cholesky

@kayween kayween marked this pull request as draft December 28, 2025 06:40
@jacobrgardner
Copy link
Copy Markdown
Member

I'd like to also get this one in also, but there are two questions to solve:

  1. (Relatively easy) What is a good default value for avoiding linear operator?

  2. (Harder, more important) The increasing speed up with n does not make intuitive sense, since the overhead due to linear operator should be constant in the simple GP setting. This makes me think the speed here may be less about avoiding linear operator, and more about some operation in inv_quad that is being really bad, indicating potential for very large speed ups when using linear operator.

@kayween
Copy link
Copy Markdown
Collaborator Author

kayween commented Dec 30, 2025

It looks like a good amount of time is spent on .contiguous, which could explain the overhead in inv_quad.

https://github.com/cornellius-gp/linear_operator/blob/210187777d882a71f511854f3eacd663a3f9c85b/linear_operator/operators/_linear_operator.py#L527-L528

Here is a small code snippet reproducing this. It's about 23% slower if we call .contiguous after Cholesky.

import torch

n = 10_000

A = torch.randn(n, n) / n
A = A + A.mT + torch.eye(n)

b = torch.randn(n, 1) / n


def solve(A, b, force_contiguous: bool) -> tuple[torch.Tensor, torch.Tensor]:
    chol = torch.linalg.cholesky(A, upper=False)

    if force_contiguous:
        chol = chol.contiguous()

    res = torch.linalg.solve_triangular(chol, b, upper=False)

    return chol, res

chol, res = solve(A, b, force_contiguous=False)
print(chol.is_contiguous())  # False
print(res.is_contiguous())  # True

chol, res = solve(A, b, force_contiguous=True)
print(chol.is_contiguous())  # True
print(res.is_contiguous())  # True

%timeit -r 20 solve(A, b, force_contiguous=False)
%timeit -r 20 solve(A, b, force_contiguous=True)

# outputs:
546 ms ± 31.1 ms per loop (mean ± std. dev. of 20 runs, 1 loop each)
713 ms ± 17.4 ms per loop (mean ± std. dev. of 20 runs, 1 loop each)

Note that torch.linalg.cholesky returns a non-contiguous tensor (somehow?) no matter upper is true or not. Making the Cholesky factor contiguous would require CPU/GPU to allocate new memory, which could incur some overhead. This overhead does not seem to be constant and it intuitively increases as n increases.

I've tried removing the contiguous call in linear operator. And it seems that inv_quad becomes almost as fast as doing linear algebra manually on tensors if we get rid of this contiguous call.

@kayween
Copy link
Copy Markdown
Collaborator Author

kayween commented Dec 30, 2025

It's also not entirely clear to me why we need this contiguous call. It seems that all test cases (both linear operator and gpytorch) are passed without it.

https://github.com/cornellius-gp/linear_operator/blob/210187777d882a71f511854f3eacd663a3f9c85b/linear_operator/operators/_linear_operator.py#L527-L528

cc @Balandat who wrote the comment that this contiguous call is necessary and might know more background on this

@jacobrgardner
Copy link
Copy Markdown
Member

Interesting. I think contiguous should only be necessary if it actually crashes without it.

@Balandat can you check botorch unit tests? My preference would be to drop the contiguous call if it's unnecessary and removing it leads to 20-40%+ speed ups for exact GPs.

Assuming it works, I guess my proposal would be to use tensors by default up to some value (1000? 4000?), and drop contiguous, switching to linear operator.

@Balandat
Copy link
Copy Markdown
Collaborator

Balandat commented Jan 4, 2026

I looked at the PR where I introduced this .contiguous() call - I didn't clearly document why I though this was necessary at the time - but that was seven years ago so it's very much possible that changes to pytorch / gpytorch / botorch have rendered this unnecessary since.

I ran the botorch and Ax unit tests on these changes; I didn't run into any failures. That doesn't necessarily mean that it's safe though - let me also kick off some e2e benchmarks to stress test this.

@Balandat
Copy link
Copy Markdown
Collaborator

Balandat commented Jan 4, 2026

I didn't see anything concerning in the benchmark runs either - I think we should be ok with removing the .contiguous() call.

Assuming it works, I guess my proposal would be to use tensors by default up to some value (1000? 4000?), and drop contiguous, switching to linear operator.

That makes sense to me.

@jacobrgardner
Copy link
Copy Markdown
Member

@kayween can we re-benchmark with .contiguous() dropped? If pure tensors still gains on linear operator over time then there's likely still something to debug (linear_operator should introduce constant overhead since the representation tree doesn't actually copy data around). If the gap closes over time, we can just pick a threshold at which we decide the overhead is negligible and use linear_operator from then on.

@kayween
Copy link
Copy Markdown
Collaborator Author

kayween commented Jan 6, 2026

I ran GP training comparing the following three methods on 4 UCI datasets:

  1. The linear operator main branch
  2. Remove the contiguous call in linear operator
  3. This PR

The results are available here https://wandb.ai/kayween/benchmark-gpytorh-2628/reports/PR-2628-Benchmark--VmlldzoxNTU1MDMwMg?accessToken=1f3l6805fm0r1dmpw5ti7ytowmv3bao3xuj1ag3ipirpr5vxfefaiyb6px1w8nkc

TL;DR

  1. This PR is still much faster than linear operator even with .contiguous() removed. This is because this PR uses a custom backward pass for the log determinant: d logdet(K) = K^{-1}. In contrast, the linear operator (specifically CholLinearOperator) backprops through the Cholesky decomposition, which is not as efficient. The custom backward pass in this PR actually saves FLOPs (not just reducing the overhead).
  2. Perhaps surprisingly, removing the contiguous call only reducing the training time marginally. It actually makes sense now that I think about it. Because GP training spends more time in the backward pass than the forward pass. Removing the contiguous call, while reducing the overhead in the forward pass, does not affect the overall training time much. Indeed, I observed that the backward pass could takes 5x more time than the forward pass when training on 10000 data points on CPUs.

So I guess this new round of benchmark does not exactly answer the question that (a) how much the linear operator overhead is and (b) where the overhead is located. To fully investigate this, we will need to implement a different version of inv_quad_logdet that uses PyTorch's Cholesky backward for the logdet term.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants