Skip to content

Commit 0ec523a

Browse files
committed
added RSD for symmetric tensors with and without min_blockdim
1 parent a9ecd68 commit 0ec523a

4 files changed

Lines changed: 499 additions & 11 deletions

File tree

include/linalg.hpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -853,8 +853,8 @@ namespace cytnx {
853853
// Rsvd:
854854
//==================================================
855855
/**
856-
* @brief Perform a truncated Singular-Value decomposition of a UniTensor.
857-
@details This function will perform a truncated Singular-Value decomposition
856+
* @brief Perform a randomized truncated Singular-Value decomposition of a UniTensor.
857+
@details This function will perform a randomized truncated Singular-Value decomposition
858858
of a UniTensor. It uses the ?gesvd method for the SVD. It will perform the randomized
859859
SVD first, and then truncate the singular values to the given cutoff \p err. That means, given a
860860
UniTensor \p Tin as \f$ M \f$, then the result will be: \f[ M = U S V^\dagger, \f] where \f$ S
@@ -908,6 +908,41 @@ namespace cytnx {
908908
cytnx_uint64 power_iteration = 0,
909909
unsigned int seed = random::__static_random_device());
910910

911+
/**
912+
* @brief Perform a truncated Singular-Value decomposition of a UniTensor and keep at most
913+
* \p min_blockdim singular values in each block.
914+
* @details For each block, the minimum dimension can be chosen for the truncation. This can be
915+
* helpful to avoid loosing symmetry sectors in the truncated SVD. For more details, please
916+
* refer to the documentation of the function \ref Rsvd(const cytnx::UniTensor &Tin,
917+
* cytnx_uint64 keepdim, double err = 0., bool is_U = true, bool is_vT = true, unsigned int
918+
* return_err = 0, cytnx_uint64 mindim = 1, cytnx_uint64 oversampling_summand = 10, double
919+
* oversampling_factor = 1., cytnx_uint64 power_iteration = 0, unsigned int seed =
920+
* random::__static_random_device()).
921+
*
922+
* The truncation order is as following (later constraints might be violated by previous
923+
* ones):<br> 1) Keep the largest \p min_blockdim singular values in each block; reduce \p
924+
* keepdim and \p mindim by the number of already kept singular values<br> 2) Keep at most \p
925+
* keepdim singular values; there might be an exception in case of exact degeneracies where more
926+
* singular values are kept<br> 3) Keep at least \p mindim singular values;<br> 4) Drop all
927+
* singular values smaller than \p err (no normalization applied to the singular values)
928+
*
929+
* @param[in] min_blockdim a vector containing the minimum dimension of each block;
930+
* alternatively, a vector with only one element can be given to have the same min_blockdim for
931+
* each block
932+
* @see Rsvd(const cytnx::UniTensor &Tin, cytnx_uint64 keepdim, double err = 0., bool is_U =
933+
* true, bool is_vT = true, unsigned int return_err = 0, cytnx_uint64 mindim = 1, cytnx_uint64
934+
* oversampling_summand = 10, double oversampling_factor = 1., cytnx_uint64 power_iteration = 0,
935+
* unsigned int seed = random::__static_random_device())
936+
*/
937+
std::vector<cytnx::UniTensor> Rsvd(const cytnx::UniTensor &Tin, cytnx_uint64 keepdim,
938+
const std::vector<cytnx_uint64> min_blockdim,
939+
double err = 0., bool is_U = true, bool is_vT = true,
940+
unsigned int return_err = 0, cytnx_uint64 mindim = 1,
941+
cytnx_uint64 oversampling_summand = 10,
942+
double oversampling_factor = 1.,
943+
cytnx_uint64 power_iteration = 0,
944+
unsigned int seed = random::__static_random_device());
945+
911946
std::vector<cytnx::UniTensor> Hosvd(
912947
const cytnx::UniTensor &Tin, const std::vector<cytnx_uint64> &mode,
913948
const bool &is_core = true, const bool &is_Ls = false,

pybind/linalg_py.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void linalg_binding(py::module &m) {
8787
py::arg("power_iteration") = 2, py::arg("seed") = -1);
8888

8989
m_linalg.def(
90-
"Rsvd",
90+
"Rsvd", // for Tensor
9191
[](const Tensor &Tin, cytnx_uint64 keepdim, double err, bool is_U, bool is_vT,
9292
unsigned int return_err, cytnx_uint64 mindim, cytnx_uint64 oversampling_summand,
9393
double oversampling_factor, cytnx_uint64 power_iteration, int64_t seed) {
@@ -103,7 +103,7 @@ void linalg_binding(py::module &m) {
103103
py::arg("oversampling_summand") = 20, py::arg("oversampling_factor") = 1.,
104104
py::arg("power_iteration") = 0, py::arg("seed") = -1);
105105
m_linalg.def(
106-
"Rsvd",
106+
"Rsvd", // for UniTensor, without min_blockdim
107107
[](const cytnx::UniTensor &Tin, cytnx_uint64 keepdim, double err, bool is_U, bool is_vT,
108108
unsigned int return_err, cytnx_uint64 mindim, cytnx_uint64 oversampling_summand,
109109
double oversampling_factor, cytnx_uint64 power_iteration, int64_t seed) {
@@ -118,6 +118,23 @@ void linalg_binding(py::module &m) {
118118
py::arg("is_vT") = true, py::arg("return_err") = (unsigned int)(0), py::arg("mindim") = 1,
119119
py::arg("oversampling_summand") = 20, py::arg("oversampling_factor") = 1.,
120120
py::arg("power_iteration") = 0, py::arg("seed") = -1);
121+
m_linalg.def(
122+
"Rsvd", // for UniTensor, with min_blockdim
123+
[](const cytnx::UniTensor &Tin, cytnx_uint64 keepdim,
124+
const std::vector<cytnx_uint64> min_blockdim, double err, bool is_U, bool is_vT,
125+
unsigned int return_err, cytnx_uint64 mindim, cytnx_uint64 oversampling_summand,
126+
double oversampling_factor, cytnx_uint64 power_iteration, int64_t seed) {
127+
if (seed == -1) {
128+
// If user doesn't specify seed argument
129+
seed = cytnx::random::__static_random_device();
130+
}
131+
return cytnx::linalg::Rsvd(Tin, keepdim, min_blockdim, err, is_U, is_vT, return_err, mindim,
132+
oversampling_summand, oversampling_factor, power_iteration, seed);
133+
},
134+
py::arg("Tin"), py::arg("keepdim"), py::arg("min_blockdim"), py::arg("err") = double(0),
135+
py::arg("is_U") = true, py::arg("is_vT") = true, py::arg("return_err") = (unsigned int)(0),
136+
py::arg("mindim") = 1, py::arg("oversampling_summand") = 20,
137+
py::arg("oversampling_factor") = 1., py::arg("power_iteration") = 0, py::arg("seed") = -1);
121138

122139
m_linalg.def(
123140
"Gesvd_truncate",

src/linalg/Gesvd_truncate.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ namespace cytnx {
211211
} // if tag
212212

213213
if (return_err) outCyT.back().Init(outT.back(), false, 0);
214-
} // Gesvd_truncate_Dense_UT_internal no minblockdim
214+
} // Gesvd_truncate_Dense_UT_internal
215215

216216
void Gesvd_truncate_Block_UT_internal(std::vector<UniTensor> &outCyT,
217217
const cytnx::UniTensor &Tin,

0 commit comments

Comments
 (0)