You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Replace deprecated torch APIs with modern torch.linalg equivalents
torch.inverse, torch.pinverse, and torch.norm have been deprecated
since PyTorch 1.9. This updates all usage to their modern replacements
and, critically, registers torch.linalg.inv for __torch_function__
dispatch so that torch.linalg.inv(linear_op) works correctly.
Changes:
- Register torch.linalg.inv alongside torch.inverse for LinearOperator
dispatch (fixes torch.linalg.inv not working on LinearOperators)
- Replace torch.pinverse() with torch.linalg.pinv()
- Replace torch.norm() with torch.linalg.vector_norm() (source files)
and torch.linalg.norm() (test files)
- Update stale comments referencing torch.cholesky, torch.solve,
torch.symeig, and torch.eig to their modern equivalents
Copy file name to clipboardExpand all lines: examples/LinearOperator_demo.ipynb
+4-4Lines changed: 4 additions & 4 deletions
Original file line number
Diff line number
Diff line change
@@ -220,7 +220,7 @@
220
220
"source": [
221
221
"#### Eigendecomposition\n",
222
222
"\n",
223
-
"This uses `__torch_function__` in order to dispatch `torch.symeig` to a custom implementation that essentially just returns the diagonal elements and the identity matrix (should sort the evals and permute the evecs to have the exact same behavior, that's an easy thing to do).\n",
223
+
"This uses `__torch_function__` in order to dispatch `torch.linalg.eigh` to a custom implementation that essentially just returns the diagonal elements and the identity matrix (should sort the evals and permute the evecs to have the exact same behavior, that's an easy thing to do).\n",
224
224
"\n",
225
225
"Time complexity goes from $\\mathcal O(n^3)$ to $\\mathcal O(1)$ (without sorting). Memory complexity goes from $\\mathcal O(n^2)$ to $\\mathcal O(n)$. \n",
226
226
"\n",
@@ -858,8 +858,8 @@
858
858
"metadata": {},
859
859
"outputs": [],
860
860
"source": [
861
-
"tri_inv = torch.inverse(tri)\n",
862
-
"tri_lo_inv = tri_lo.inverse() # TODO: Handle in torch.inverse by registering via __torch_function__\n",
0 commit comments