
Commit 9261a26

Replace deprecated torch APIs with modern torch.linalg equivalents
torch.inverse, torch.pinverse, and torch.norm have been deprecated since PyTorch 1.9. This updates all usage to their modern replacements and, critically, registers torch.linalg.inv for __torch_function__ dispatch so that torch.linalg.inv(linear_op) works correctly.

Changes:
- Register torch.linalg.inv alongside torch.inverse for LinearOperator dispatch (fixes torch.linalg.inv not working on LinearOperators)
- Replace torch.pinverse() with torch.linalg.pinv()
- Replace torch.norm() with torch.linalg.vector_norm() (source files) and torch.linalg.norm() (test files)
- Update stale comments referencing torch.cholesky, torch.solve, torch.symeig, and torch.eig to their modern equivalents
1 parent cdebc6e commit 9261a26

17 files changed

Lines changed: 72 additions & 52 deletions
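The heart of the fix is the dispatch registration in _linear_operator.py. Below is a minimal, self-contained sketch of the mechanism, for orientation only: DiagOperator, _register, and _HANDLED are hypothetical stand-ins for the library's LinearOperator and _implements, and only the __torch_function__ protocol itself comes from PyTorch. Dispatch keys on the exact function object, and torch.inverse and torch.linalg.inv are distinct objects, which is why each must be registered separately.

import torch

class DiagOperator:
    _HANDLED = {}  # maps a torch function object to its custom implementation

    def __init__(self, diag):
        self.diag = diag

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in cls._HANDLED:
            return NotImplemented
        return cls._HANDLED[func](*args, **kwargs)

def _register(*funcs):
    def decorator(impl):
        for f in funcs:
            DiagOperator._HANDLED[f] = impl
        return impl
    return decorator

# Registering only torch.inverse (the old behavior) would leave
# torch.linalg.inv(op) raising a TypeError, because the two are
# distinct function objects.
@_register(torch.inverse, torch.linalg.inv)
def _diag_inverse(op):
    return DiagOperator(op.diag.reciprocal())  # O(n) inverse of a diagonal

op = DiagOperator(torch.tensor([2.0, 4.0]))
print(torch.linalg.inv(op).diag)  # tensor([0.5000, 0.2500])
print(torch.inverse(op).diag)     # tensor([0.5000, 0.2500])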

examples/LinearOperator_demo.ipynb

Lines changed: 4 additions & 4 deletions
@@ -220,7 +220,7 @@
    "source": [
     "#### Eigendecomposition\n",
     "\n",
-    "This uses `__torch_function__` in order to dispatch `torch.symeig` to a custom implementation that essentially just returns the diagonal elements and the identity matrix (should sort the evals and permute the evecs to have the exact same behavior, that's an easy thing to do).\n",
+    "This uses `__torch_function__` in order to dispatch `torch.linalg.eigh` to a custom implementation that essentially just returns the diagonal elements and the identity matrix (should sort the evals and permute the evecs to have the exact same behavior, that's an easy thing to do).\n",
     "\n",
     "Time complexity goes from $\\mathcal O(n^3)$ to $\\mathcal O(1)$ (without sorting). Memory complexity goes from $\\mathcal O(n^2)$ to $\\mathcal O(n)$. \n",
     "\n",
@@ -858,8 +858,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "tri_inv = torch.inverse(tri)\n",
-    "tri_lo_inv = tri_lo.inverse() # TODO: Handle in torch.inverse by registering via __torch_function__\n",
+    "tri_inv = torch.linalg.inv(tri)\n",
+    "tri_lo_inv = torch.linalg.inv(tri_lo)\n",
     "\n",
     "assert torch.allclose(tri_inv, tri_lo_inv.to_dense())"
    ]
@@ -879,7 +879,7 @@
     }
    ],
    "source": [
-    "t_d = %timeit -o torch.inverse(tri)"
+    "t_d = %timeit -o torch.linalg.inv(tri)"
    ]
   },
   {

linear_operator/functions/_inv_quad_logdet.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ def forward(
         else:
             probe_vectors = precond_lt.zero_mean_mvn_samples(num_random_probes)
             probe_vectors = probe_vectors.unsqueeze(-2).transpose(0, -2).squeeze(0).mT.contiguous()
-        probe_vector_norms = torch.norm(probe_vectors, p=2, dim=-2, keepdim=True)
+        probe_vector_norms = torch.linalg.vector_norm(probe_vectors, ord=2, dim=-2, keepdim=True)
         probe_vectors = probe_vectors.div(probe_vector_norms)
 
         # Probe vectors
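For the torch.norm to torch.linalg.vector_norm changes in this and the following source files, the two calls are numerically interchangeable when reducing over a single dimension; only the parameter name changes (p= becomes ord=). A quick illustrative check, not part of the commit:

import torch

x = torch.randn(5, 3)

# ord=2 over one dim: same value as the deprecated call it replaces here.
assert torch.allclose(
    torch.linalg.vector_norm(x, ord=2, dim=-2, keepdim=True),
    torch.norm(x, p=2, dim=-2, keepdim=True),
)

# The ord=1 replacements in _pivoted_cholesky.py rest on the same equivalence.
assert torch.allclose(
    torch.linalg.vector_norm(x, ord=1, dim=-1),
    torch.norm(x, 1, dim=-1),
)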

linear_operator/functions/_pivoted_cholesky.py

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@ def forward(ctx, representation_tree, max_iter, error_tol, *matrix_args):
             device=matrix.device,
         )
         orig_error = torch.max(matrix_diag, dim=-1)[0]
-        errors = torch.norm(matrix_diag, 1, dim=-1) / orig_error
+        errors = torch.linalg.vector_norm(matrix_diag, ord=1, dim=-1) / orig_error
 
         # The permutation
         permutation = torch.arange(0, matrix_shape[-1], dtype=torch.long, device=matrix_diag.device)
@@ -96,7 +96,7 @@ def forward(ctx, representation_tree, max_iter, error_tol, *matrix_args):
             L[..., m, :] = L_m
 
             # Keep track of errors - for potential early stopping
-            errors = torch.norm(matrix_diag.gather(-1, pi_i), 1, dim=-1) / orig_error
+            errors = torch.linalg.vector_norm(matrix_diag.gather(-1, pi_i), ord=1, dim=-1) / orig_error
 
             m = m + 1

linear_operator/operators/_linear_operator.py

Lines changed: 8 additions & 4 deletions
@@ -80,7 +80,7 @@ def _implements_second_arg(torch_function: Callable) -> Callable:
     where the first argument of the function is a torch.Tensor and the
     second argument is a LinearOperator
 
-    Examples of this include :meth:`torch.cholesky_solve`, `torch.solve`, or `torch.matmul`.
+    Examples of this include :meth:`torch.cholesky_solve`, `torch.linalg.solve`, or `torch.matmul`.
     """
 
     @functools.wraps(torch_function)
@@ -1803,13 +1803,17 @@ def inv_quad_logdet(
             inv_quad_term = inv_quad_term.sum(-1)
         return inv_quad_term, logdet_term
 
+    @_implements(torch.linalg.inv)
     @_implements(torch.inverse)
     def inverse(
         self: LinearOperator,  # shape: (*batch, N, N)
     ) -> LinearOperator:  # shape: (*batch, N, N)
         # Only implemented by some LinearOperator subclasses
-        # We define it here so that we can map the torch function torch.inverse to the LinearOperator method
-        raise NotImplementedError(f"torch.inverse({self.__class__.__name__}) is not implemented.")
+        # We define it here so that we can map torch.linalg.inv / torch.inverse to the LinearOperator method
+        raise NotImplementedError(
+            f"torch.linalg.inv({self.__class__.__name__}) is not implemented. "
+            "The LinearOperator subclass must implement the `inverse` method."
+        )
 
     @property
     def is_square(self) -> bool:
@@ -2296,7 +2300,7 @@ def root_inv_decomposition(
         elif method == "pinverse":
             # this is numerically unstable and should rarely be used
             root = self.root_decomposition().root.to_dense()
-            inv_root = torch.pinverse(root).mT
+            inv_root = torch.linalg.pinv(root).mT
         else:
             raise RuntimeError(f"Unknown root inv decomposition method '{method}'")
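A side note on the pinverse hunk: torch.linalg.pinv is the documented replacement for the deprecated torch.pinverse and computes the same Moore-Penrose pseudo-inverse, so the change is behavior-preserving. Illustrative spot-check, not part of the commit:

import torch

root = torch.randn(4, 3)
# Same pseudo-inverse, new spelling; .mT in the hunk then transposes it as before.
assert torch.allclose(torch.pinverse(root), torch.linalg.pinv(root))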

linear_operator/operators/kronecker_product_added_diag_linear_operator.py

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ def _logdet(
 
         else:
             # we use the same matrix determinant identity: |K + D| = |D| |I + D^{-1}K|
-            # but have to symmetrize the second matrix because torch.eig may not be
+            # but have to symmetrize the second matrix because torch.linalg.eig may not be
             # completely differentiable.
             lt = self.linear_op
             dlt = self.diag_tensor
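For context on the comment being updated: after symmetrizing, one can take the symmetric-eigenvalue path, which yields real eigenvalues with well-defined gradients, whereas the gradient of a general (non-symmetric) eigendecomposition is only defined under extra conditions. An illustrative sketch, not code from this repository:

import torch

M = torch.randn(3, 3, requires_grad=True)
sym = 0.5 * (M + M.mT)              # symmetrize first, as the comment describes
evals = torch.linalg.eigvalsh(sym)  # real eigenvalues, differentiable
evals.sum().backward()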

linear_operator/utils/cholesky.py

Lines changed: 2 additions & 2 deletions
@@ -56,9 +56,9 @@ def psd_safe_cholesky(A, upper=False, out=None, jitter=None, max_tries=None):
     :attr:`A` (Tensor):
         The tensor to compute the Cholesky decomposition of
     :attr:`upper` (bool, optional):
-        See torch.cholesky
+        See torch.linalg.cholesky
     :attr:`out` (Tensor, optional):
-        See torch.cholesky
+        See torch.linalg.cholesky
     :attr:`jitter` (float, optional):
         The jitter to add to the diagonal of A in case A is only p.s.d. If omitted,
         uses settings.cholesky_jitter.value()
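The docstring above only gestures at the jitter behavior, so here is an illustrative sketch of the retry strategy it refers to. This is simplified, not the real psd_safe_cholesky: the actual function also consults settings.cholesky_jitter and forwards upper/out. RuntimeError is caught because older PyTorch raises it directly, while newer releases raise torch.linalg.LinAlgError, which subclasses it.

import torch

def psd_safe_cholesky_sketch(A, jitter=1e-6, max_tries=3):
    try:
        return torch.linalg.cholesky(A)  # plain factorization first
    except RuntimeError:
        pass
    # A is (numerically) only p.s.d.: retry with growing diagonal jitter.
    eye = torch.eye(A.size(-1), dtype=A.dtype, device=A.device)
    for i in range(max_tries):
        try:
            return torch.linalg.cholesky(A + jitter * (10**i) * eye)
        except RuntimeError:
            continue
    raise RuntimeError(f"Matrix not positive definite, even with jitter of {jitter * 10 ** (max_tries - 1):.1e}")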

linear_operator/utils/lanczos.py

Lines changed: 4 additions & 4 deletions
@@ -78,7 +78,7 @@ def lanczos_tridiag(
 
     # Begin algorithm
     # Initial Q vector: q_0_vec
-    q_0_vec = init_vecs / torch.norm(init_vecs, 2, dim=dim_dimension).unsqueeze(dim_dimension)
+    q_0_vec = init_vecs / torch.linalg.vector_norm(init_vecs, ord=2, dim=dim_dimension).unsqueeze(dim_dimension)
     q_mat[0].copy_(q_0_vec)
 
     # Initial alpha value: alpha_0
@@ -87,7 +87,7 @@ def lanczos_tridiag(
 
     # Initial beta value: beta_0
     r_vec.sub_(alpha_0.unsqueeze(dim_dimension).mul(q_0_vec))
-    beta_0 = torch.norm(r_vec, 2, dim=dim_dimension)
+    beta_0 = torch.linalg.vector_norm(r_vec, ord=2, dim=dim_dimension)
 
     # Copy over alpha_0 and beta_0 to t_mat
     t_mat[0, 0].copy_(alpha_0)
@@ -118,7 +118,7 @@ def lanczos_tridiag(
             correction = r_vec.unsqueeze(0).mul(q_mat[: k + 1]).sum(dim_dimension, keepdim=True)
             correction = q_mat[: k + 1].mul(correction).sum(0)
             r_vec.sub_(correction)
-            r_vec_norm = torch.norm(r_vec, 2, dim=dim_dimension, keepdim=True)
+            r_vec_norm = torch.linalg.vector_norm(r_vec, ord=2, dim=dim_dimension, keepdim=True)
             r_vec.div_(r_vec_norm)
 
             # Get next beta value
@@ -137,7 +137,7 @@ def lanczos_tridiag(
             correction = r_vec.unsqueeze(0).mul(q_mat[: k + 1]).sum(dim_dimension, keepdim=True)
             correction = q_mat[: k + 1].mul(correction).sum(0)
             r_vec.sub_(correction)
-            r_vec_norm = torch.norm(r_vec, 2, dim=dim_dimension, keepdim=True)
+            r_vec_norm = torch.linalg.vector_norm(r_vec, ord=2, dim=dim_dimension, keepdim=True)
             r_vec.div_(r_vec_norm)
             inner_products = q_mat[: k + 1].mul(r_vec.unsqueeze(0)).sum(dim_dimension)

linear_operator/utils/linear_cg.py

Lines changed: 1 addition & 1 deletion
@@ -296,7 +296,7 @@ def linear_cg(
             curr_conjugate_vec,
         )
 
-        torch.norm(residual, 2, dim=-2, keepdim=True, out=residual_norm)
+        torch.linalg.vector_norm(residual, ord=2, dim=-2, keepdim=True, out=residual_norm)
         residual_norm.masked_fill_(rhs_is_zero, 0)
         torch.lt(residual_norm, stop_updating_after, out=has_converged)

linear_operator/utils/minres.py

Lines changed: 2 additions & 2 deletions
@@ -182,8 +182,8 @@
 
         # Check convergence criterion
         if (i + 1) % 10 == 0:
-            torch.norm(search_update, dim=-2, out=search_update_norm)
-            torch.norm(solution, dim=-2, out=solution_norm)
+            torch.linalg.vector_norm(search_update, dim=-2, out=search_update_norm)
+            torch.linalg.vector_norm(solution, dim=-2, out=solution_norm)
             conv = search_update_norm.div_(solution_norm).mean().item()
             if conv < settings.minres_tolerance.value():
                 break
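Two details the linear_cg and minres hunks above rely on, verified here illustratively (not part of the commit): torch.linalg.vector_norm accepts an out= tensor just as torch.norm did, so the in-place convergence bookkeeping keeps working, and its default ord is 2, which matches what the deprecated default computed when reducing over a single dimension.

import torch

x = torch.randn(4, 3)
out = torch.empty(3)

torch.linalg.vector_norm(x, dim=-2, out=out)       # default ord=2, written in place
assert torch.allclose(out, torch.norm(x, dim=-2))  # deprecated default, same value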

test/functions/test_dsmm.py

Lines changed: 14 additions & 14 deletions
@@ -16,7 +16,7 @@ def test_forward(self):
 
         res = linear_operator.dsmm(sparse, dense)
         actual = torch.mm(sparse.to_dense(), dense)
-        self.assertLess(torch.norm(res - actual), 1e-5)
+        self.assertLess(torch.linalg.norm(res - actual), 1e-5)
 
     def test_forward_batch(self):
         i = torch.tensor(
@@ -29,7 +29,7 @@ def test_forward_batch(self):
 
         res = linear_operator.dsmm(sparse, dense)
         actual = torch.matmul(sparse.to_dense(), dense)
-        self.assertLess(torch.norm(res - actual), 1e-5)
+        self.assertLess(torch.linalg.norm(res - actual), 1e-5)
 
     def test_forward_multi_batch(self):
         i = torch.tensor(
@@ -47,7 +47,7 @@ def test_forward_multi_batch(self):
 
         res = linear_operator.dsmm(sparse, dense)
         actual = torch.matmul(sparse.to_dense(), dense)
-        self.assertLess(torch.norm(res - actual), 1e-5)
+        self.assertLess(torch.linalg.norm(res - actual), 1e-5)
 
     def test_backward(self):
         i = torch.tensor([[0, 1, 1], [2, 0, 2]], dtype=torch.long)
@@ -61,7 +61,7 @@ def test_backward(self):
         res.backward(grad_output)
         actual = torch.mm(sparse.to_dense(), dense_copy)
         actual.backward(grad_output)
-        self.assertLess(torch.norm(dense.grad - dense_copy.grad).item(), 1e-5)
+        self.assertLess(torch.linalg.norm(dense.grad - dense_copy.grad).item(), 1e-5)
 
     def test_backward_batch(self):
         i = torch.tensor(
@@ -78,7 +78,7 @@ def test_backward_batch(self):
         res.backward(grad_output)
         actual = torch.matmul(sparse.to_dense(), dense_copy)
         actual.backward(grad_output)
-        self.assertLess(torch.norm(dense.grad - dense_copy.grad).item(), 1e-5)
+        self.assertLess(torch.linalg.norm(dense.grad - dense_copy.grad).item(), 1e-5)
 
     def test_backward_multi_batch(self):
         i = torch.tensor(
@@ -100,7 +100,7 @@ def test_backward_multi_batch(self):
         res.backward(grad_output)
         actual = torch.matmul(sparse.to_dense(), dense_copy)
         actual.backward(grad_output)
-        self.assertLess(torch.norm(dense.grad - dense_copy.grad).item(), 1e-5)
+        self.assertLess(torch.linalg.norm(dense.grad - dense_copy.grad).item(), 1e-5)
 
     def test_broadcast_rhs(self):
         i = torch.tensor([[0, 1, 1, 0, 1, 1], [2, 0, 2, 2, 0, 2]], dtype=torch.long)
@@ -111,12 +111,12 @@ def test_broadcast_rhs(self):
 
         res = linear_operator.dsmm(sparse, dense)
         actual = torch.matmul(sparse.to_dense(), dense_copy)
-        self.assertLess(torch.norm(res - actual), 1e-5)
+        self.assertLess(torch.linalg.norm(res - actual), 1e-5)
 
         grad_output = torch.randn(4, 2, 2, 4)
         res.backward(grad_output)
         actual.backward(grad_output)
-        self.assertLess(torch.norm(dense.grad - dense_copy.grad).item(), 1e-5)
+        self.assertLess(torch.linalg.norm(dense.grad - dense_copy.grad).item(), 1e-5)
 
         i = torch.tensor(
             [[0, 0, 0, 1, 1, 1], [0, 1, 1, 0, 1, 1], [2, 0, 2, 2, 0, 2]],
@@ -129,12 +129,12 @@ def test_broadcast_rhs(self):
 
         res = linear_operator.dsmm(sparse, dense)
         actual = torch.matmul(sparse.to_dense(), dense_copy)
-        self.assertLess(torch.norm(res - actual), 1e-5)
+        self.assertLess(torch.linalg.norm(res - actual), 1e-5)
 
         grad_output = torch.randn(4, 2, 2, 4)
         res.backward(grad_output)
         actual.backward(grad_output)
-        self.assertLess(torch.norm(dense.grad - dense_copy.grad).item(), 1e-5)
+        self.assertLess(torch.linalg.norm(dense.grad - dense_copy.grad).item(), 1e-5)
 
     def test_broadcast_sparse(self):
         i = torch.tensor(
@@ -148,12 +148,12 @@ def test_broadcast_sparse(self):
 
         res = linear_operator.dsmm(sparse, dense)
         actual = torch.matmul(sparse.to_dense(), dense_copy)
-        self.assertLess(torch.norm(res - actual), 1e-5)
+        self.assertLess(torch.linalg.norm(res - actual), 1e-5)
 
         grad_output = torch.randn(2, 2, 4)
         res.backward(grad_output)
         actual.backward(grad_output)
-        self.assertLess(torch.norm(dense.grad - dense_copy.grad).item(), 1e-5)
+        self.assertLess(torch.linalg.norm(dense.grad - dense_copy.grad).item(), 1e-5)
 
     def test_broadcast_singleton(self):
         i = torch.tensor(
@@ -167,12 +167,12 @@ def test_broadcast_singleton(self):
 
         res = linear_operator.dsmm(sparse, dense)
         actual = torch.matmul(sparse.to_dense(), dense_copy)
-        self.assertLess(torch.norm(res - actual), 1e-5)
+        self.assertLess(torch.linalg.norm(res - actual), 1e-5)
 
         grad_output = torch.randn(2, 2, 4)
         res.backward(grad_output)
         actual.backward(grad_output)
-        self.assertLess(torch.norm(dense.grad - dense_copy.grad).item(), 1e-5)
+        self.assertLess(torch.linalg.norm(dense.grad - dense_copy.grad).item(), 1e-5)
 
 
 if __name__ == "__main__":
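In the tests the norm argument is a matrix or a batch of matrices, so the drop-in replacement is torch.linalg.norm: with ord unspecified it flattens its input and takes the 2-norm, the same value as the deprecated torch.norm default. Illustrative check, not part of the commit:

import torch

res, actual = torch.randn(2, 2, 4), torch.randn(2, 2, 4)
assert torch.allclose(torch.linalg.norm(res - actual), torch.norm(res - actual))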
