Skip to content

Commit a8ee89c

Browse files
committed
added Rsvd_notruncate for symmetric and fermionic UniTensors
1 parent a40726f commit a8ee89c

6 files changed

Lines changed: 1246 additions & 153 deletions

File tree

include/linalg.hpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -708,19 +708,30 @@ namespace cytnx {
708708
subspace, and the smallest singular values are less reliable. Use truncation with oversampling.
709709
@param[in] is_U if \em true, the left-unitary UniTensor U (isometry) is returned.
710710
@param[in] is_vT if \em true, the right-unitary UniTensor vT (isometry) is returned.
711+
@param[in] mindim at least this number of singular values is kept in each block.
712+
@param[in] oversampling_summand the randomized SVD computes [(1 + oversampling_factor) *
713+
keepdim*d/D + oversampling_summand] singular values in each block before further truncating,
714+
where d is the block dimension and D the full tensor dimension (each being the minimum of row
715+
and column dimension of the block or tensor respectively).
716+
@param[in] oversampling_factor see oversampling_summand
711717
@param[in] power_iteration number of iterations for the power method: Y = (A *
712718
Adag)^power_iteration * A * Tin
713719
@param[in] seed the seed for the random generator. [Default] Using device entropy.
714720
@see Rand_isometry(const Tensor &Tin, const cytnx_uint64 &keepdim, const cytnx_uint64
715721
&power_iteration, const unsigned int &seed)
716722
@see Rsvd()
717723
@see Gesvd()
724+
@note At least one singular value per symmetry sector is always kept.
725+
@note More than keepdim singular values might be returned (depending on mindim,
726+
oversampling_summand, oversampling_factor, and symmetry sector sizes).
718727
@warning No truncation of the singular values is performed, and the smaller ones will be
719728
inaccurate. Use Rsvd() for a truncated version that drops small singular values.
720729
*/
721730
std::vector<cytnx::UniTensor> Rsvd_notruncate(
722731
const cytnx::UniTensor &Tin, cytnx_uint64 keepdim, bool is_U = true, bool is_vT = true,
723-
cytnx_uint64 power_iteration = 2, unsigned int seed = random::__static_random_device());
732+
cytnx_uint64 mindim = 1, cytnx_uint64 oversampling_summand = 0,
733+
double oversampling_factor = 0., cytnx_uint64 power_iteration = 2,
734+
unsigned int seed = random::__static_random_device());
724735

725736
/**
726737
@brief Perform Singular-Value decomposition on a UniTensor using ?gesvd method.
@@ -865,10 +876,12 @@ namespace cytnx {
865876
the largest error will be pushed back to the vector (The smallest singular value in the return
866877
singular values matrix \f$ S \f$.) If \p return_err is > 1, then the full list of truncated
867878
singular values will be returned.
868-
@param[in] oversampling_summand the randomized SVD computes (1 + oversampling_fact) * keepdim +
869-
oversampling_summand before further truncating to keepdim
870-
@param[in] oversampling_fact the randomized SVD computes (1 + oversampling_fact) * keepdim +
871-
oversampling_summand before further truncating to keepdim
879+
@param[in] mindim at least this number of singular values is kept in each block.
880+
@param[in] oversampling_summand the randomized SVD computes [(1 + oversampling_factor) *
881+
keepdim*d/D + oversampling_summand] singular values in each block before further truncating,
882+
where d is the block dimension and D the full tensor dimension (each being the minimum of row
883+
and column dimension of the block or tensor respectively).
884+
@param[in] oversampling_factor see oversampling_summand
872885
@param[in] power_iteration number of iterations for the power method: Y = (A *
873886
Adag)^power_iteration * A * Tin
874887
@param[in] seed the seed for the random generator. [Default] Using device entropy.
@@ -1887,10 +1900,12 @@ namespace cytnx {
18871900
the largest error will be pushed back to the vector (The smallest singular value in the return
18881901
singular values matrix \f$ S \f$.) If \p return_err is a \em positive int, then the
18891902
full list of truncated singular values will be returned.
1890-
@param[in] oversampling_summand the randomized SVD computes (1 + oversampling_fact) * keepdim +
1891-
oversampling_summand before further truncating to keepdim
1892-
@param[in] oversampling_fact the randomized SVD computes (1 + oversampling_fact) * keepdim +
1893-
oversampling_summand before further truncating to keepdim
1903+
@param[in] mindim at least this number of singular values is kept in each block.
1904+
@param[in] oversampling_summand the randomized SVD computes [(1 + oversampling_factor) *
1905+
keepdim*d/D + oversampling_summand] singular values in each block before further truncating,
1906+
where d is the block dimension and D the full tensor dimension (each being the minimum of row
1907+
and column dimension of the block or tensor respectively).
1908+
@param[in] oversampling_factor see oversampling_summand
18941909
@param[in] power_iteration number of iterations for the power method: Y = (A *
18951910
Adag)^power_iteration * A * Tin
18961911
@param[in] seed the seed for the random generator. [Default] Using device entropy.

pybind/linalg_py.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,17 @@ void linalg_binding(py::module &m) {
7373
m_linalg.def(
7474
"Rsvd_notruncate",
7575
[](const cytnx::UniTensor &Tin, cytnx_uint64 keepdim, bool is_U, bool is_vT,
76+
cytnx_uint64 mindim, cytnx_uint64 oversampling_summand, double oversampling_factor,
7677
cytnx_uint64 power_iteration, int64_t seed) {
7778
if (seed == -1) {
7879
// If user doesn't specify seed argument
7980
seed = cytnx::random::__static_random_device();
8081
}
81-
return cytnx::linalg::Rsvd_notruncate(Tin, keepdim, is_U, is_vT, power_iteration, seed);
82+
return cytnx::linalg::Rsvd_notruncate(Tin, keepdim, is_U, is_vT, mindim, oversampling_summand,
83+
oversampling_factor, power_iteration, seed);
8284
},
8385
py::arg("Tin"), py::arg("keepdim"), py::arg("is_U") = true, py::arg("is_vT") = true,
86+
py::arg("mindim") = 1, py::arg("oversampling_summand") = 0, py::arg("oversampling_factor") = 0.,
8487
py::arg("power_iteration") = 2, py::arg("seed") = -1);
8588

8689
m_linalg.def(

src/algo/Vstack.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ namespace cytnx {
1616
typedef Accessor ac;
1717

1818
Tensor Vstack(const std::vector<Tensor> &In_tensors) {
19-
Tensor out;
20-
2119
std::vector<Tensor> _Ins;
2220

2321
// check:
@@ -82,7 +80,7 @@ namespace cytnx {
8280
}
8381

8482
// allocate out!
85-
out = zeros({Dcomb, Dshare}, dtype_id, device_id);
83+
Tensor out = zeros({Dcomb, Dshare}, dtype_id, device_id);
8684

8785
std::vector<void *> rawPtr(In_tensors.size());
8886
for (int i = 0; i < _Ins.size(); i++) {

src/linalg/Rsvd.cpp

Lines changed: 149 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,151 @@ namespace cytnx {
240240

241241
if (return_err) outCyT.back().Init(outT.back(), false, 0);
242242
} // Rsvd_Dense_UT_internal
243+
244+
void Rsvd_Block_UT_internal(std::vector<UniTensor> &outCyT, const cytnx::UniTensor &Tin,
245+
const cytnx_uint64 &keepdim, const double &err, const bool &is_U,
246+
const bool &is_vT, const unsigned int &return_err,
247+
const cytnx_uint64 &mindim) {
248+
cytnx_uint64 keep_dim = keepdim;
249+
250+
outCyT = linalg::Gesvd(Tin, is_U, is_vT);
251+
252+
// process truncation:
253+
// 1) concate all S vals from all blk
254+
Tensor Sall = outCyT[0].get_block_(0);
255+
for (int i = 1; i < outCyT[0].Nblocks(); i++) {
256+
Sall = algo::Concatenate(Sall, outCyT[0].get_block_(i));
257+
}
258+
Sall = algo::Sort(Sall); // all singular values, starting from the smallest
259+
260+
// 2) get the minimum S value based on the args input.
261+
Scalar Smin;
262+
cytnx_uint64 smidx;
263+
cytnx_uint64 Sshape = Sall.shape()[0];
264+
if (keep_dim < Sshape) {
265+
smidx = Sshape - keep_dim;
266+
Smin = Sall.storage()(smidx);
267+
} else {
268+
keep_dim = Sshape;
269+
smidx = 0;
270+
Smin = Sall.storage()(0);
271+
}
272+
while ((Smin < err) and (keep_dim > (mindim < 1 ? 1 : mindim))) {
273+
// at least one singular value is always kept!
274+
keep_dim--;
275+
// if (keep_dim == 0) break;
276+
smidx++;
277+
Smin = Sall.storage()(smidx);
278+
}
279+
280+
// traversal each block and truncate!
281+
UniTensor &S = outCyT[0];
282+
std::vector<cytnx_uint64> new_dims; // keep_dims for each block!
283+
std::vector<cytnx_int64> keep_dims;
284+
keep_dims.reserve(S.Nblocks());
285+
std::vector<cytnx_int64> new_qid;
286+
new_qid.reserve(S.Nblocks());
287+
288+
std::vector<std::vector<cytnx_uint64>>
289+
new_itoi; // assume S block is in same order as qnum:
290+
std::vector<cytnx_uint64> to_be_removed;
291+
292+
cytnx_uint64 tot_dim = 0;
293+
cytnx_uint64 cnt = 0;
294+
for (int b = 0; b < S.Nblocks(); b++) {
295+
Storage stmp = S.get_block_(b).storage();
296+
cytnx_int64 kdim = 0;
297+
for (int i = stmp.size(); i > 0; i--) {
298+
if (stmp(i - 1) >= Smin) {
299+
kdim = i;
300+
break;
301+
}
302+
}
303+
keep_dims.push_back(kdim);
304+
if (kdim == 0) {
305+
to_be_removed.push_back(b);
306+
new_qid.push_back(-1);
307+
308+
} else {
309+
new_qid.push_back(new_dims.size());
310+
new_itoi.push_back({new_dims.size(), new_dims.size()});
311+
new_dims.push_back(kdim);
312+
tot_dim += kdim;
313+
if (kdim != S.get_blocks_()[b].shape()[0])
314+
S.get_blocks_()[b] = S.get_blocks_()[b].get({Accessor::range(0, kdim)});
315+
}
316+
}
317+
318+
// remove:
319+
// vec_erase_(S.get_itoi(),to_be_removed);
320+
S.get_itoi() = new_itoi;
321+
vec_erase_(S.get_blocks_(), to_be_removed);
322+
vec_erase_(S.bonds()[0].qnums(), to_be_removed);
323+
S.bonds()[0]._impl->_degs = new_dims;
324+
S.bonds()[0]._impl->_dim = tot_dim;
325+
S.bonds()[1] = S.bonds()[0].redirect();
326+
327+
int t = 1;
328+
if (is_U) {
329+
UniTensor &U = outCyT[t];
330+
to_be_removed.clear();
331+
U.bonds().back() = S.bonds()[1].clone();
332+
std::vector<Accessor> acs(U.rank());
333+
for (int i = 0; i < U.rowrank(); i++) acs[i] = Accessor::all();
334+
335+
for (int b = 0; b < U.Nblocks(); b++) {
336+
if (keep_dims[U.get_qindices(b).back()] == 0)
337+
to_be_removed.push_back(b);
338+
else {
339+
/// process blocks:
340+
if (keep_dims[U.get_qindices(b).back()] != U.get_blocks_()[b].shape().back()) {
341+
acs.back() = Accessor::range(0, keep_dims[U.get_qindices(b).back()]);
342+
U.get_blocks_()[b] = U.get_blocks_()[b].get(acs);
343+
}
344+
345+
// change to new qindices:
346+
U.get_qindices(b).back() = new_qid[U.get_qindices(b).back()];
347+
}
348+
}
349+
vec_erase_(U.get_itoi(), to_be_removed);
350+
vec_erase_(U.get_blocks_(), to_be_removed);
351+
352+
t++;
353+
}
354+
355+
if (is_vT) {
356+
UniTensor &vT = outCyT[t];
357+
to_be_removed.clear();
358+
vT.bonds().front() = S.bonds()[0].clone();
359+
std::vector<Accessor> acs(vT.rank());
360+
for (int i = 1; i < vT.rank(); i++) acs[i] = Accessor::all();
361+
362+
for (int b = 0; b < vT.Nblocks(); b++) {
363+
if (keep_dims[vT.get_qindices(b)[0]] == 0)
364+
to_be_removed.push_back(b);
365+
else {
366+
/// process blocks:
367+
if (keep_dims[vT.get_qindices(b)[0]] != vT.get_blocks_()[b].shape()[0]) {
368+
acs[0] = Accessor::range(0, keep_dims[vT.get_qindices(b)[0]]);
369+
vT.get_blocks_()[b] = vT.get_blocks_()[b].get(acs);
370+
}
371+
// change to new qindices:
372+
vT.get_qindices(b)[0] = new_qid[vT.get_qindices(b)[0]];
373+
}
374+
}
375+
vec_erase_(vT.get_itoi(), to_be_removed);
376+
vec_erase_(vT.get_blocks_(), to_be_removed);
377+
t++;
378+
}
379+
380+
// handle return_err!
381+
if (return_err == 1) {
382+
outCyT.push_back(UniTensor(Tensor({1}, Smin.dtype())));
383+
outCyT.back().get_block_().storage().at(0) = Smin;
384+
} else if (return_err) {
385+
outCyT.push_back(UniTensor(Sall.get({Accessor::tilend(smidx)})));
386+
}
387+
} // Rsvd_Block_UT_internal
243388
} // unnamed namespace
244389

245390
std::vector<cytnx::UniTensor> Rsvd(const cytnx::UniTensor &Tin, cytnx_uint64 keepdim,
@@ -258,18 +403,16 @@ namespace cytnx {
258403
"\n");
259404

260405
// check input arguments
261-
// cytnx_error_msg(mindim < 0, "[ERROR][Rsvd] mindim must be >=1%s", "\n");
406+
cytnx_error_msg(mindim < 0, "[ERROR][Rsvd] mindim must be >=1%s", "\n");
262407
cytnx_error_msg(keepdim < 1, "[ERROR][Rsvd] keepdim must be >=1%s", "\n");
263-
// cytnx_error_msg(return_err < 0, "[ERROR][Rsvd] return_err cannot be negative%s",
264-
// "\n");
408+
cytnx_error_msg(return_err < 0, "[ERROR][Rsvd] return_err cannot be negative%s", "\n");
265409

266410
std::vector<UniTensor> outCyT;
267411
if (Tin.uten_type() == UTenType.Dense) {
268412
Rsvd_Dense_UT_internal(outCyT, Tin, keepdim, err, is_U, is_vT, return_err, mindim,
269413
oversampling_summand, oversampling_factor, power_iteration, seed);
270-
// } else if (Tin.uten_type() == UTenType.Block) {
271-
// _Rsvd_Block_UT(outCyT, Tin, keepdim, err, is_U, is_vT,
272-
// return_err, mindim);
414+
} else if (Tin.uten_type() == UTenType.Block) {
415+
Rsvd_Block_UT_internal(outCyT, Tin, keepdim, err, is_U, is_vT, return_err, mindim);
273416
} else {
274417
cytnx_error_msg(true, "[ERROR][Rsvd] only Dense UniTensors are supported.%s", "\n");
275418
}

0 commit comments

Comments
 (0)