@@ -240,6 +240,151 @@ namespace cytnx {
240240
241241 if (return_err) outCyT.back ().Init (outT.back (), false , 0 );
242242 } // Rsvd_Dense_UT_internal
243+
244+ void Rsvd_Block_UT_internal (std::vector<UniTensor> &outCyT, const cytnx::UniTensor &Tin,
245+ const cytnx_uint64 &keepdim, const double &err, const bool &is_U,
246+ const bool &is_vT, const unsigned int &return_err,
247+ const cytnx_uint64 &mindim) {
248+ cytnx_uint64 keep_dim = keepdim;
249+
250+ outCyT = linalg::Gesvd (Tin, is_U, is_vT);
251+
252+ // process truncation:
253+ // 1) concate all S vals from all blk
254+ Tensor Sall = outCyT[0 ].get_block_ (0 );
255+ for (int i = 1 ; i < outCyT[0 ].Nblocks (); i++) {
256+ Sall = algo::Concatenate (Sall, outCyT[0 ].get_block_ (i));
257+ }
258+ Sall = algo::Sort (Sall); // all singular values, starting from the smallest
259+
260+ // 2) get the minimum S value based on the args input.
261+ Scalar Smin;
262+ cytnx_uint64 smidx;
263+ cytnx_uint64 Sshape = Sall.shape ()[0 ];
264+ if (keep_dim < Sshape) {
265+ smidx = Sshape - keep_dim;
266+ Smin = Sall.storage ()(smidx);
267+ } else {
268+ keep_dim = Sshape;
269+ smidx = 0 ;
270+ Smin = Sall.storage ()(0 );
271+ }
272+ while ((Smin < err) and (keep_dim > (mindim < 1 ? 1 : mindim))) {
273+ // at least one singular value is always kept!
274+ keep_dim--;
275+ // if (keep_dim == 0) break;
276+ smidx++;
277+ Smin = Sall.storage ()(smidx);
278+ }
279+
280+ // traversal each block and truncate!
281+ UniTensor &S = outCyT[0 ];
282+ std::vector<cytnx_uint64> new_dims; // keep_dims for each block!
283+ std::vector<cytnx_int64> keep_dims;
284+ keep_dims.reserve (S.Nblocks ());
285+ std::vector<cytnx_int64> new_qid;
286+ new_qid.reserve (S.Nblocks ());
287+
288+ std::vector<std::vector<cytnx_uint64>>
289+ new_itoi; // assume S block is in same order as qnum:
290+ std::vector<cytnx_uint64> to_be_removed;
291+
292+ cytnx_uint64 tot_dim = 0 ;
293+ cytnx_uint64 cnt = 0 ;
294+ for (int b = 0 ; b < S.Nblocks (); b++) {
295+ Storage stmp = S.get_block_ (b).storage ();
296+ cytnx_int64 kdim = 0 ;
297+ for (int i = stmp.size (); i > 0 ; i--) {
298+ if (stmp (i - 1 ) >= Smin) {
299+ kdim = i;
300+ break ;
301+ }
302+ }
303+ keep_dims.push_back (kdim);
304+ if (kdim == 0 ) {
305+ to_be_removed.push_back (b);
306+ new_qid.push_back (-1 );
307+
308+ } else {
309+ new_qid.push_back (new_dims.size ());
310+ new_itoi.push_back ({new_dims.size (), new_dims.size ()});
311+ new_dims.push_back (kdim);
312+ tot_dim += kdim;
313+ if (kdim != S.get_blocks_ ()[b].shape ()[0 ])
314+ S.get_blocks_ ()[b] = S.get_blocks_ ()[b].get ({Accessor::range (0 , kdim)});
315+ }
316+ }
317+
318+ // remove:
319+ // vec_erase_(S.get_itoi(),to_be_removed);
320+ S.get_itoi () = new_itoi;
321+ vec_erase_ (S.get_blocks_ (), to_be_removed);
322+ vec_erase_ (S.bonds ()[0 ].qnums (), to_be_removed);
323+ S.bonds ()[0 ]._impl ->_degs = new_dims;
324+ S.bonds ()[0 ]._impl ->_dim = tot_dim;
325+ S.bonds ()[1 ] = S.bonds ()[0 ].redirect ();
326+
327+ int t = 1 ;
328+ if (is_U) {
329+ UniTensor &U = outCyT[t];
330+ to_be_removed.clear ();
331+ U.bonds ().back () = S.bonds ()[1 ].clone ();
332+ std::vector<Accessor> acs (U.rank ());
333+ for (int i = 0 ; i < U.rowrank (); i++) acs[i] = Accessor::all ();
334+
335+ for (int b = 0 ; b < U.Nblocks (); b++) {
336+ if (keep_dims[U.get_qindices (b).back ()] == 0 )
337+ to_be_removed.push_back (b);
338+ else {
339+ // / process blocks:
340+ if (keep_dims[U.get_qindices (b).back ()] != U.get_blocks_ ()[b].shape ().back ()) {
341+ acs.back () = Accessor::range (0 , keep_dims[U.get_qindices (b).back ()]);
342+ U.get_blocks_ ()[b] = U.get_blocks_ ()[b].get (acs);
343+ }
344+
345+ // change to new qindices:
346+ U.get_qindices (b).back () = new_qid[U.get_qindices (b).back ()];
347+ }
348+ }
349+ vec_erase_ (U.get_itoi (), to_be_removed);
350+ vec_erase_ (U.get_blocks_ (), to_be_removed);
351+
352+ t++;
353+ }
354+
355+ if (is_vT) {
356+ UniTensor &vT = outCyT[t];
357+ to_be_removed.clear ();
358+ vT.bonds ().front () = S.bonds ()[0 ].clone ();
359+ std::vector<Accessor> acs (vT.rank ());
360+ for (int i = 1 ; i < vT.rank (); i++) acs[i] = Accessor::all ();
361+
362+ for (int b = 0 ; b < vT.Nblocks (); b++) {
363+ if (keep_dims[vT.get_qindices (b)[0 ]] == 0 )
364+ to_be_removed.push_back (b);
365+ else {
366+ // / process blocks:
367+ if (keep_dims[vT.get_qindices (b)[0 ]] != vT.get_blocks_ ()[b].shape ()[0 ]) {
368+ acs[0 ] = Accessor::range (0 , keep_dims[vT.get_qindices (b)[0 ]]);
369+ vT.get_blocks_ ()[b] = vT.get_blocks_ ()[b].get (acs);
370+ }
371+ // change to new qindices:
372+ vT.get_qindices (b)[0 ] = new_qid[vT.get_qindices (b)[0 ]];
373+ }
374+ }
375+ vec_erase_ (vT.get_itoi (), to_be_removed);
376+ vec_erase_ (vT.get_blocks_ (), to_be_removed);
377+ t++;
378+ }
379+
380+ // handle return_err!
381+ if (return_err == 1 ) {
382+ outCyT.push_back (UniTensor (Tensor ({1 }, Smin.dtype ())));
383+ outCyT.back ().get_block_ ().storage ().at (0 ) = Smin;
384+ } else if (return_err) {
385+ outCyT.push_back (UniTensor (Sall.get ({Accessor::tilend (smidx)})));
386+ }
387+ } // Rsvd_Block_UT_internal
243388 } // unnamed namespace
244389
245390 std::vector<cytnx::UniTensor> Rsvd (const cytnx::UniTensor &Tin, cytnx_uint64 keepdim,
@@ -258,18 +403,16 @@ namespace cytnx {
258403 " \n " );
259404
260405 // check input arguments
261- // cytnx_error_msg(mindim < 0, "[ERROR][Rsvd] mindim must be >=1%s", "\n");
406+ cytnx_error_msg (mindim < 0 , " [ERROR][Rsvd] mindim must be >=1%s" , " \n " );
262407 cytnx_error_msg (keepdim < 1 , " [ERROR][Rsvd] keepdim must be >=1%s" , " \n " );
263- // cytnx_error_msg(return_err < 0, "[ERROR][Rsvd] return_err cannot be negative%s",
264- // "\n");
408+ cytnx_error_msg (return_err < 0 , " [ERROR][Rsvd] return_err cannot be negative%s" , " \n " );
265409
266410 std::vector<UniTensor> outCyT;
267411 if (Tin.uten_type () == UTenType.Dense ) {
268412 Rsvd_Dense_UT_internal (outCyT, Tin, keepdim, err, is_U, is_vT, return_err, mindim,
269413 oversampling_summand, oversampling_factor, power_iteration, seed);
270- // } else if (Tin.uten_type() == UTenType.Block) {
271- // _Rsvd_Block_UT(outCyT, Tin, keepdim, err, is_U, is_vT,
272- // return_err, mindim);
414+ } else if (Tin.uten_type () == UTenType.Block ) {
415+ Rsvd_Block_UT_internal (outCyT, Tin, keepdim, err, is_U, is_vT, return_err, mindim);
273416 } else {
274417 cytnx_error_msg (true , " [ERROR][Rsvd] only Dense UniTensors are supported.%s" , " \n " );
275418 }
0 commit comments