Skip to content

Commit fd4c019

Browse files
Replace deprecated libtorch APIs with torch::linalg equivalents
- THSTensor_cholesky: Replace tensor->cholesky(upper) with torch::linalg_cholesky, using .mH() for upper-triangular results - THSTensor_lu_solve: Replace tensor->lu_solve() with torch::linalg_lu_solve - Fix THSLinalg_lu_solve naming bug (was incorrectly named THSTensor_lu_solve) - Add [Obsolete] attributes to C# wrappers for cholesky() and lu_solve() - Update tests to use linalg.lu_solve instead of deprecated lu_solve
1 parent c79ac52 commit fd4c019

File tree

5 files changed

+17
-5
lines changed

5 files changed

+17
-5
lines changed

src/Native/LibTorchSharp/THSLinearAlgebra.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,14 +351,16 @@ Tensor THSLinalg_vecdot(const Tensor x, const Tensor y, const int64_t dim, Tenso
351351
CATCH_TENSOR(out == nullptr ? torch::linalg_vecdot(* x, *y, dim) : torch::linalg_vecdot_out(*out, *x, *y, dim))
352352
}
353353

354-
Tensor THSTensor_lu_solve(const Tensor B, const Tensor LU, const Tensor pivots, bool left, bool adjoint, Tensor out)
354+
Tensor THSLinalg_lu_solve(const Tensor B, const Tensor LU, const Tensor pivots, bool left, bool adjoint, Tensor out)
355355
{
356356
CATCH_TENSOR(out == nullptr ? torch::linalg_lu_solve(*LU, *pivots, *B, left, adjoint) : torch::linalg_lu_solve_out(*out, *LU, *pivots, *B, left, adjoint))
357357
}
358358

359359
Tensor THSTensor_cholesky(const Tensor tensor, const bool upper)
360360
{
361-
CATCH_TENSOR(tensor->cholesky(upper))
361+
// torch::cholesky is deprecated in favor of torch::linalg_cholesky.
362+
// linalg_cholesky always returns lower-triangular; use .mH() for upper.
363+
CATCH_TENSOR(upper ? torch::linalg_cholesky(*tensor).mH() : torch::linalg_cholesky(*tensor))
362364
}
363365

364366
Tensor THSTensor_cholesky_inverse(const Tensor tensor, const bool upper)
@@ -441,7 +443,9 @@ Tensor THSTensor_lu(const Tensor tensor, bool pivot, bool get_infos, Tensor* inf
441443

442444
Tensor THSTensor_lu_solve(const Tensor tensor, const Tensor LU_data, const Tensor LU_pivots)
443445
{
444-
CATCH_TENSOR(tensor->lu_solve(*LU_data, *LU_pivots))
446+
// tensor.lu_solve is deprecated in favor of torch::linalg_lu_solve.
447+
// Note: linalg_lu_solve arg order is (LU, pivots, B).
448+
CATCH_TENSOR(torch::linalg_lu_solve(*LU_data, *LU_pivots, *tensor))
445449
}
446450

447451
Tensor THSTensor_lu_unpack(const Tensor LU_data, const Tensor LU_pivots, bool unpack_data, bool unpack_pivots, Tensor* L, Tensor* U)

src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ public Tensor tensordot(Tensor b, long dims = 2)
6464
/// </summary>
6565
/// <param name="upper">If upper is true, the returned matrix U is upper-triangular. If upper is false, the returned matrix L is lower-triangular</param>
6666
/// <returns></returns>
67+
[Obsolete("torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be removed in a future release. Use torch.linalg.cholesky instead.", false)]
6768
public Tensor cholesky(bool upper = false)
6869
{
6970
var res = THSTensor_cholesky(Handle, upper);

src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ public static Tensor addbmm_(Tensor input, Tensor batch1, Tensor batch2, float b
131131

132132
// https://pytorch.org/docs/stable/generated/torch.cholesky
133133

134+
[Obsolete("torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be removed in a future release. Use torch.linalg.cholesky instead.", false)]
135+
#pragma warning disable CS0618 // Obsolete
134136
public static Tensor cholesky(Tensor input) => input.cholesky();
137+
#pragma warning restore CS0618
135138

136139
// https://pytorch.org/docs/stable/generated/torch.cholesky_inverse
137140
/// <summary>
@@ -250,6 +253,7 @@ public static (Tensor A_LU, Tensor? pivots, Tensor? infos) lu(Tensor A, bool piv
250253
/// The pivots of the LU factorization from torch.lu() of size (∗,m), where *∗ is zero or more batch dimensions.
251254
/// The batch dimensions of LU_pivots must be equal to the batch dimensions of LU_data.</param>
252255
/// <returns></returns>
256+
[Obsolete("torch.lu_solve is deprecated in favor of torch.linalg.lu_solve and will be removed in a future release. Use torch.linalg.lu_solve(LU, pivots, B) instead.", false)]
253257
public static Tensor lu_solve(Tensor b, Tensor LU_data, Tensor LU_pivots)
254258
{
255259
var solution = THSTensor_lu_solve(b.Handle, LU_data.Handle, LU_pivots.Handle);

src/TorchSharp/Tensor/torch.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ public static partial class torch
1616
/// <param name="input">The input matrix</param>
1717
/// <param name="upper">If upper is true, the returned matrix U is upper-triangular. If upper is false, the returned matrix L is lower-triangular</param>
1818
/// <returns></returns>
19+
[Obsolete("torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be removed in a future release. Use torch.linalg.cholesky instead.", false)]
20+
#pragma warning disable CS0618 // Obsolete
1921
public static Tensor cholesky(Tensor input, bool upper) => input.cholesky(upper);
22+
#pragma warning restore CS0618
2023

2124
/// <summary>
2225
/// Returns the matrix norm or vector norm of a given tensor.

test/TorchSharpTest/LinearAlgebra.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void TestLUSolve()
4949
Assert.Equal(new long[] { 2, 3, 3 }, A_LU.shape);
5050
Assert.Equal(new long[] { 2, 3 }, pivots.shape);
5151

52-
var x = lu_solve(b, A_LU, pivots);
52+
var x = linalg.lu_solve(A_LU, pivots, b);
5353
Assert.Equal(new long[] { 2, 3, 1 }, x.shape);
5454

5555
var y = norm(bmm(A, x) - b);
@@ -67,7 +67,7 @@ public void TestLUSolve()
6767
Assert.Equal(new long[] { 2, 3 }, pivots.shape);
6868
Assert.Equal(new long[] { 2 }, infos.shape);
6969

70-
var x = lu_solve(b, A_LU, pivots);
70+
var x = linalg.lu_solve(A_LU, pivots, b);
7171
Assert.Equal(new long[] { 2, 3, 1 }, x.shape);
7272

7373
var y = norm(bmm(A, x) - b);

0 commit comments

Comments
 (0)